在 Halide 中使用 split() 时可以避免计算相同的元素吗?

Can I avoid calculating same elements when using split() in Halide?

我对 Halide 语言中 split() 的行为有疑问。

当我使用 split() 时,当计算区域不是分割因子的倍数时,它会在边缘计算元素两次。例如,当计算区域为10,分裂因子为4时,Halide将计算元素[0,1,2,3],[4,5,6,7]和[6,7,8,9],如下所示下面 trace_stores() 的结果。

有没有办法在split()内循环的最后一步只计算元素[8,9]?

示例代码:

#include "Halide.h"
using namespace Halide;

#define INPUT_SIZE 10
int main(int argc, char** argv) {
    Func f("f");
    Var x("x");
    f(x) = x;

    Var xi("xi");
    f.split(x, x, xi, 4); 

    f.trace_stores();
    Image<int32_t> out = f.realize(INPUT_SIZE);
    return 0;
}

trace_stores() 结果:

Store f.0(0) = 0
Store f.0(1) = 1
Store f.0(2) = 2
Store f.0(3) = 3
Store f.0(4) = 4
Store f.0(5) = 5
Store f.0(6) = 6
Store f.0(7) = 7
Store f.0(6) = 6
Store f.0(7) = 7
Store f.0(8) = 8
Store f.0(9) = 9

这是可能的,但很难看。 Halide 通常假设它可以任意重新评估 Func 中的点,并且输入不会与输出混淆,因此重新计算边缘附近的一些值总是安全的。

这很重要,这是一个不好的迹象。可能还有其他方法可以实现您想要做的事情。

无论如何,解决方法是使用显式 RDoms 准确地告诉 Halide 迭代什么:

// No pure definition
f(x) = undef<int>(); 

// An update stage that does the vectorized part:
Expr w = (input.width()/4)*4;
RDom r(0, w);
f(r) = something;
f.update(0).vectorize(r, 4);

// An update stage that does the tail end:
RDom r2(input.width(), input.width() - w);
f(r2) = something;
f.update(1); // Don't vectorize the tail end