如何让函数针对不同的 x 值进行不同的计算?

How can a Halide Func be computed differently for different ranges of x?


    // Asker's first attempt: one definition of `support` per horizontal region,
    // intended to apply to x in [0, w/4), [w/4, 3w/4) and [3w/4, w) respectively.
    // Every definition below has the left-hand side support(x, y), i.e. it writes
    // the whole (x, y) domain; clamp() only limits the coordinate that is *read*,
    // not which pixels are written. So each later definition replaces the earlier
    // result everywhere -- the overwriting described in the question.
    Func support("support");

    // Region 1: left quarter.
    Expr left_x = clamp(x, 0, left_buffer.width() / 4);
    RDom scan_left(0, left_buffer.width() / 4, 0, left_buffer.height());
    // `&&` keeps only candidates outside BOTH column left_x and row y.
    // NOTE(review): if only the pixel itself should be skipped, `||` was probably intended.
    scan_left.where(scan_left.x != left_x && scan_left.y != y);
    // argmin() without an explicit RDom reduces over the RDom used in its argument
    // (scan_left) and yields the location of the minimum together with its value.
    support(x, y) = argmin(abs(output_x(left_x, y) - output_x(scan_left.x, scan_left.y)) + abs(output_y(left_x, y) - output_y(scan_left.x, scan_left.y)));

    // Region 2: middle half. scan_center.x is an offset relative to center_x,
    // while scan_center.y is an absolute row.
    Expr center_x = clamp(x, left_buffer.width() / 4, left_buffer.width() * 3/4);
    RDom scan_center(-left_buffer.width() / 4, left_buffer.width() / 2, 0, left_buffer.height());
    // NOTE(review): `scan_center.y != 0` skips row 0, unlike the `!= y` tests of the
    // other regions; `scan_center.y != y` was likely intended.
    scan_center.where(scan_center.x != 0 && scan_center.y != 0);
    // NOTE(review): output_y(center_x, scan_center.y) differs from output_x(center_x, y)
    // and from the other regions' output_y(..., y); looks like a typo for output_y(center_x, y).
    support(x, y) = argmin(abs(output_x(center_x, y) - output_x(center_x + scan_center.x, scan_center.y)) + abs(output_y(center_x, scan_center.y) - output_y(center_x + scan_center.x, scan_center.y)));

    // Region 3: right quarter.
    Expr right_x = clamp(x, left_buffer.width() * 3/4, left_buffer.width());
    RDom scan_right(left_buffer.width() * 3/4, left_buffer.width() / 4, 0, left_buffer.height());
    scan_right.where(scan_right.x != right_x && scan_right.y != y);
    // This is the last definition, so it is the one whose values survive for every (x, y).
    support(x, y) = argmin(abs(output_x(right_x, y) - output_x(scan_right.x, scan_right.y)) + abs(output_y(right_x, y) - output_y(scan_right.x, scan_right.y)));

    // Print every value stored into support, to observe the overwriting.
    support.trace_stores();

    // Realize over the full width x height of the input buffer.
    Realization r = support.realize(left_buffer.width(), left_buffer.height());

函数 `support` 应根据 x 的取值采用不同的计算方式:当 x ∈ [0, width/4) 时按第一个定义计算,当 x ∈ [width/4, 3·width/4) 时按第二个定义计算,当 x ∈ [3·width/4, width) 时按第三个定义计算。我原以为只要给这几个更新定义设置好边界条件,再对整个缓冲区调用一次 realize 就可以了,但实际上前面的定义被后面的定义覆盖了。既然这样行不通,我考虑分别调用三次 realize,但这似乎不够优雅,毕竟处理的只是同一张图像。能否只调用一次 realize 就得到结果,还是必须拆成三次?我也尝试过使用 RDom:

// Asker's second attempt: a pure definition that initializes every pixel, followed
// by three update definitions whose x coordinate is an RDom covering one region each.
Func support("support");
    // 3-element initial value; presumably (argmin x, argmin y, cost) to match what
    // argmin over a 2-D scan RDom yields -- TODO confirm the cost really is float.
    support(x, y) = Tuple(i32(0), i32(0), f32(0));

    // Region 1: x in [0, width/4) is iterated by the RDom left_x.
    RDom left_x(0, left_buffer.width() / 4);
    RDom scan_left(0, left_buffer.width() / 4, 0, left_buffer.height());
    // The question reports an error on this line. NOTE(review): this predicate refers
    // to a different RDom (left_x) and to the pure Var y; presumably that is what is
    // rejected -- confirm against the actual error message.
    scan_left.where(scan_left.x != left_x && scan_left.y != y);
    // argmin(rdom, expr) is the explicit-RDom form of the inline reduction.
    support(left_x, y) = argmin(scan_left, abs(output_x(left_x, y) - output_x(scan_left.x, scan_left.y)) + abs(output_y(left_x, y) - output_y(scan_left.x, scan_left.y)));

    // Region 2: x in [width/4, 3*width/4); scan_center.x is an offset from center_x.
    RDom center_x(left_buffer.width() / 4, left_buffer.width() / 2);
    RDom scan_center(-left_buffer.width() / 4, left_buffer.width() / 2, 0, left_buffer.height());
    // NOTE(review): `scan_center.y != 0` skips row 0; `!= y` was likely intended.
    scan_center.where(scan_center.x != 0 && scan_center.y != 0);
    // The trailing backslashes below are ordinary line splices (not needed outside a macro).
    // NOTE(review): output_y(center_x, scan_center.y) looks like a typo for output_y(center_x, y).
    support(center_x, y) = argmin(scan_center, abs(output_x(center_x, y) - \
    output_x(center_x + scan_center.x, scan_center.y)) + abs(output_y(center_x, scan_center.y) - \
    output_y(center_x + scan_center.x, scan_center.y)));

    // Region 3: x in [3*width/4, width).
    RDom right_x(left_buffer.width() * 3/4, left_buffer.width() / 4);
    RDom scan_right(left_buffer.width() * 3/4, left_buffer.width() / 4, 0, left_buffer.height());
    // The question reports an error on this line as well (same pattern as region 1).
    scan_right.where(scan_right.x != right_x && scan_right.y != y);
    support(right_x, y) = argmin(scan_right, abs(output_x(right_x, y) - output_x(scan_right.x, scan_right.y)) + abs(output_y(right_x, y) - output_y(scan_right.x, scan_right.y)));

    // Compute support at root level and print every store for debugging.
    support.compute_root();
    support.trace_stores();

    // Realize over the full width x height of the input buffer.
    Realization r_left = support.realize(left_buffer.width(), left_buffer.height());

但这段代码在下面这两行报错:

scan_left.where(scan_left.x != left_x && scan_left.y != y);
...
scan_right.where(scan_right.x != right_x && scan_right.y != y);

一个简单的解决方法是使用 Halide 的 `select` 函数(可参考 Halide 文档中 `select` 的示例)。类似下面的写法应该可行:

Func support("support");

Expr left_x = clamp(x, 0, left_buffer.width() / 4);
RDom scan_left(0, left_buffer.width() / 4, 0, left_buffer.height());
scan_left.where(scan_left.x != left_x && scan_left.y != y);
Expr first = argmin(abs(output_x(left_x, y) - output_x(scan_left.x, scan_left.y)) + abs(output_y(left_x, y) - output_y(scan_left.x, scan_left.y)));

Expr center_x = clamp(x, left_buffer.width() / 4, left_buffer.width() * 3/4);
RDom scan_center(-left_buffer.width() / 4, left_buffer.width() / 2, 0, left_buffer.height());
scan_center.where(scan_center.x != 0 && scan_center.y != 0);
Expr second = argmin(abs(output_x(center_x, y) - output_x(center_x + scan_center.x, scan_center.y)) + abs(output_y(center_x, scan_center.y) - output_y(center_x + scan_center.x, scan_center.y)));

Expr right_x = clamp(x, left_buffer.width() * 3/4, left_buffer.width());
RDom scan_right(left_buffer.width() * 3/4, left_buffer.width() / 4, 0, left_buffer.height());
scan_right.where(scan_right.x != right_x && scan_right.y != y);
Expr third = argmin(abs(output_x(right_x, y) - output_x(scan_right.x, scan_right.y)) + abs(output_y(right_x, y) - output_y(scan_right.x, scan_right.y)));

int width = left_buffer.width();
# select based on x value
support(x, y) = select(x < width / 4, first, x < width * 3 / 4, second, third);

support.trace_stores();

Realization r = support.realize(left_buffer.width(), left_buffer.height());

`