Halide зависает во время нормализованной взаимной корреляции

Я пытаюсь реализовать нормализованную взаимную корреляцию в Halide.

Приведенный ниже код строится, и JIT-компиляция Halide не выдает никаких ошибок. Однако Halide, кажется, зависает после JIT-компиляции. Независимо от того, сколько trace_* вызовов я делаю для разных функций, всегда печатается только одна трассировка (на Func output):

Begin realization normxcorr.0(0, 2028, 0, 2028)
Produce normxcorr.0(0, 2028, 0, 2028)

Любой совет вообще будет полезен.

Этот алгоритм должен быть эквивалентен CV_TM_CCOEFF_NORMED в OpenCV и normxcorr2 в MATLAB:

void normxcorr( Halide::ImageParam input,
                Halide::ImageParam kernel,
                Halide::Param<pixel_t> kernel_mean,
                Halide::Param<pixel_t> kernel_var,
                Halide::Func& output )
{
    Halide::Var x, y;
    Halide::RDom rk( kernel );

    // reduction domain for cumulative sums
    Halide::RDom ri( 1, input.width() - kernel.width() - 1, 
                     1, input.height() - kernel.height() - 1 );

    Halide::Func input_32( "input32" ),
             bounded_input( "bounded_input"),
             kernel_32( "kernel32" ),
             knorm( "knorm" ),
             conv( "conv" ),
             normxcorr( "normxcorr_internal" ),
             sq_sum_x( "sq_sum_x" ),
             sq_sum_x_local( "sq_sum_x_local" ),
             sq_sum_y( "sq_sum_y" ),
             sq_sum_y_local( "sq_sum_y_local" ),
             sum_x( "sum_x" ),
             sum_x_local( "sum_x_local" ),
             sum_y( "sum_y" ),
             sum_y_local( "sum_y_local" ),
             win_var( "win_var" ),
             win_mean( "win_mean" );

    Halide::Expr ksize = kernel.width() * kernel.height();

    // accessing outside the input image always returns 0
    bounded_input( x, y ) = Halide::BoundaryConditions::constant_exterior( input, 0 )( x, y );

    // cast to 32-bit to make room for multiplication
    input_32( x, y ) = Halide::cast<int32_t>( bounded_input( x, y ) );
    kernel_32( x, y ) = Halide::cast<int32_t>( kernel( x, y ) );

    // cumulative sum along each row
    sum_x( x, y ) = input_32( x, y );
    sum_x( ri.x, ri.y ) += sum_x( ri.x - 1, ri.y );

    // sum of 1 x W strips
    // (W is the width of the kernel)
    sum_x_local( x, y ) = sum_x( x + kernel.width() - 1, y );
    sum_x_local( x, y ) -= sum_x( x - 1, y );

    // cumulative sums of the 1 x W strips along each column
    sum_y( x, y ) = sum_x_local( x, y );
    sum_y( ri.x, ri.y ) += sum_y( ri.x, ri.y - 1);

    // sums up H strips (as above) to get the sum of an H x W rectangle
    // (H is the height of the kernel)
    sum_y_local( x, y ) = sum_y( x, y + kernel.height() - 1 );
    sum_y_local( x, y ) -= sum_y( x, y - 1 );

    // same as above, just with squared image values
    sq_sum_x( x, y ) = input_32( x, y ) * input_32( x, y );
    sq_sum_x( ri.x, ri.y ) += sq_sum_x( ri.x - 1, ri.y );

    sq_sum_x_local( x, y ) = sq_sum_x( x + kernel.width() - 1, y );
    sq_sum_x_local( x, y ) -= sq_sum_x( x - 1, y );

    sq_sum_y( x, y ) = sq_sum_x_local( x, y );
    sq_sum_y( ri.x, ri.y ) += sq_sum_y( ri.x, ri.y - 1);

    sq_sum_y_local( x, y ) = sq_sum_y( x, y + kernel.height() - 1 );
    sq_sum_y_local( x, y ) -= sq_sum_y( x, y - 1 );

    // the mean value of each window
    win_mean( x, y ) = sum_y_local( x, y ) / ksize;

    // the variance of each window
    win_var( x, y ) =  sq_sum_y_local( x, y ) / ksize;
    win_var( x, y) -= win_mean( x, y ) * win_mean( x, y );

    // partially normalize the kernel
    // (we'll divide by std. dev. at the end)
    knorm( x, y ) = kernel_32( x, y ) - kernel_mean;

    // convolve kernel and the input
    conv( x, y ) = Halide::sum( knorm( rk.x, rk.y ) * input_32( x + rk.x, y + rk.y ) );

    // calculate normxcorr, except scaled to 0 to 254 (for an 8-bit image)
    normxcorr( x, y ) = conv( x, y ) * 127 / Halide::sqrt( kernel_var * win_var( x, y ) ) + 127;

    // after scaling pixel values, it's safe to cast down to 8-bit
    output( x, y ) = Halide::cast<pixel_t>( normxcorr( x, y ) );
}

person kevlar1818    schedule 25.06.2015    source источник


Ответы (1)


Я считаю, что проблема здесь просто в том, что вы не указали никакого расписания ни для одной из ваших функций, поэтому все встраивается, что приводит к огромному количеству избыточных вычислений промежуточных значений. Так что на самом деле это не технически зависание, а просто выполнение огромного объема работы для каждого пикселя и, следовательно, невозможность завершения за разумное время.

Для начала попробуйте сказать, что каждая функция должна быть compute_root (например, sum_x.compute_root();), проще всего в блоке в конце функции. Это должно происходить в гораздо более разумном темпе, должно печатать каждую функцию (начиная с входных данных) одну за другой, а не только normxcore.0, и должно завершиться.

На самом деле многие из ваших функций на самом деле являются просто точечными преобразованиями своих входных данных, поэтому эти входные данные можно оставить встроенными (не compute_root), что должно еще больше ускорить работу (особенно после того, как вы начнете распараллеливать и векторизовать некоторые этапы). На первый взгляд, [sq_]sum_{x,y}, вероятно, не следует встраивать, но все остальное, вероятно, можно оставить встроенным. knorm и input_32 — это жеребьевка, в зависимости от вашей цели и расписания.

Я добавил быструю рабочую версию с этим тривиальным расписанием и некоторыми другими небольшими исправлениями, здесь:

https://gist.github.com/d64823d754a732106a60

В моих тестах он работает на входе 2K ^ 2 менее чем за секунду без каких-либо особых изысков.

Кроме того, небольшой совет: компиляция вашего кода генератора с символами отладки (-g) должна освободить вас от указания строк имен во всех ваших объявлениях Func. Это было досадной проблемой в более ранних реализациях, но теперь мы можем сделать разумную работу, устанавливая эти имена непосредственно из имен исходных символов C++, если вы компилируете с включенными символами отладки.

person jrk    schedule 02.07.2015
comment
Спасибо! Одно небольшое замечание: даже с включенными символами отладки сообщения Halide не используют автоматически читаемую функцию/буфер/и т.д. имена; они отображаются как f0.0 и т. д. Я использую Halide v4.9. - person kevlar1818; 07.07.2015
comment
Странный. Какую платформу/тулчейн вы используете? Это работает довольно стабильно с последними выпусками для Linux и Mac; Я не уверен, работает ли он в Windows. (Я не уверен, что означает версия 4.9? Релизы названы по дате: github.com/halide /Halide/релизы.) - person jrk; 09.07.2015
comment
Я использую 64-битный Linux, компилирую с clang++. По этой ссылке (кстати, с плохой конечной точкой) есть два релиза Linux 64_trunk_gcc, датированные 2015_06_03. Единственная разница в названии — префикс 4.8 против 4.9. Я тоже не уверен, что это значит... - person kevlar1818; 13.07.2015
comment
Поздно отвечать на это, но 4.8 и 4.9 не являются версиями Halide, они относятся к версии цепочки инструментов gcc, используемой для сборки данного дистрибутива Halide. - person jrk; 22.08.2015