Я реализую функцию радиального базиса в Halide, и хотя она у меня успешно работает, она довольно медленная. Для каждого пикселя я вычисляю расстояние, затем беру взвешенную сумму этого расстояния для получения результата. Для перебора весов я использую RDom (как показано ниже). В этой реализации каждое вычисление пикселя требует перезагрузки всех многих (3000+) весов, отсюда и низкая скорость.
Мой вопрос заключается в том, как в этом случае воспользоваться преимуществами функции планирования Halide. Я хочу загрузить некоторые веса, вычислить частичные взвешенные суммы для подмножества пикселей, загрузить следующий набор весов и продолжить до завершения. Это сохраняет локальность для каждой меньшей группы весов, и это именно то, для чего создан Halide. К сожалению, я не нашел ничего для этой конкретной проблемы. RDom, кажется, находится на более низком уровне абстракции, чем примитивы планирования, поэтому неясно, как это запланировать.
Приветствуются любые альтернативные предложения по реализации взвешенной суммы в Halide. Не нужно делать это с RDom, я просто не знаю другого способа.
Func rbf_ctrl_pts("rbf_ctrl_pts");
// Initialization with all zero
rbf_ctrl_pts(x,y,c) = cast<float>(0);
// Index to iterate with
RDom idx(0,num_ctrl_pts);
// Loop code
// Subtract the vectors
Expr red_sub = (*in_func)(x,y,0) - (*ctrl_pts_h)(0,idx);
Expr green_sub = (*in_func)(x,y,1) - (*ctrl_pts_h)(1,idx);
Expr blue_sub = (*in_func)(x,y,2) - (*ctrl_pts_h)(2,idx);
// Take the L2 norm to get the distance
Expr dist = sqrt( red_sub*red_sub +
green_sub*green_sub +
blue_sub*blue_sub );
// Update persistant loop variables
rbf_ctrl_pts(x,y,c) = select( c == 0, rbf_ctrl_pts(x,y,c) +
( (*weights_h)(0,idx) * dist),
c == 1, rbf_ctrl_pts(x,y,c) +
( (*weights_h)(1,idx) * dist),
rbf_ctrl_pts(x,y,c) +
( (*weights_h)(2,idx) * dist));