如何使用域计算 Halide 中 n 通道的 maximum/minimum?

How to calculate maximum/minimum over n-channels in Halide using domains?

我目前正在试用 Halide,尝试在图像的所有通道上计算 maximum/minimum。我想为任意图像实现这一点,其中通道数量仅在运行时才知道。

我成功得到以下解决方案:

#include "Halide.h"
#include "halide_image_io.h"
using namespace Halide::Tools;

int main(int argc, char **argv) {
    Halide::Buffer<uint8_t> input = load_image("rgb.png");

    Halide::Var x, y;
    Halide::Func max_channels, min_channels;
    max_channels(x, y) = input(x, y, 0);
    min_channels(x, y) = input(x, y, 0);
    for (int i = 1; i < input.channels(); ++i)
    {
        max_channels(x, y) = max(max_channels(x, y), input(x, y, i));
        min_channels(x, y) = min(min_channels(x, y), input(x, y, i));
    }

    {
        Halide::Buffer<uint8_t> output =
            max_channels.realize({input.width(), input.height(), 1});
        save_image(output, "maximum.png");
    }
    {
        Halide::Buffer<uint8_t> output =
            min_channels.realize({input.width(), input.height(), 1});
        save_image(output, "minimum.png");
    }
    printf("Success!\n");

    return 0;
}

但是,我想知道是否可以在没有显式 for 循环的情况下实现这一点。按照 documentation,应该可以使用 Halide::RDom class 来实现。根据此处给出的示例,我认为以下内容应该有效。

#include "Halide.h"
#include "halide_image_io.h"
using namespace Halide::Tools;

int main(int argc, char **argv) {
    Halide::Buffer<uint8_t> input = load_image("rgb.png");

    Halide::Var x, y;
    Halide::Func max_channels, min_channels;
    Halide::RDom r(input);
    min_channels(x, y) = uint8_t{255};
    max_channels(x, y) = uint8_t{0};
    min_channels(r.x, r.y) = minimum(input(r.x, r.y, r.z));
    max_channels(r.x, r.y) = maximum(input(r.x, r.y, r.z));

    {
        Halide::Buffer<uint8_t> output =
            max_channels.realize({input.width(), input.height(), 1});
        save_image(output, "maximum.png");
    }
    {
        Halide::Buffer<uint8_t> output =
            min_channels.realize({input.width(), input.height(), 1});
        save_image(output, "minimum.png");
    }
    printf("Success!\n");

    return 0;
}

这可以编译,但不幸的是崩溃并显示以下错误消息:

terminate called after throwing an instance of 'Halide::CompileError'
what():  Error: In update definition 0 of Func "f1":
Tuple element 0 of update definition has type uint8, but pure definition has type int32

Aborted (core dumped)

这条消息对我来说没有任何意义,因为所有涉及的值都是 uint8_t,所以我不确定程序从哪里得到它的 int32_t

因此,我的问题是:使用域计算图像所有通道的 maximum/minimum 的正确方法是什么? 或者这是不可能的,并且必须使用 for 循环来代替?

下面是使用 maximumminimum 助手的方法:

#include <Halide.h>
#include <halide_image_io.h>
using namespace Halide;
using namespace Halide::Tools;

int main() {
  Buffer<uint8_t> input = load_image("rgb.png");

  Var x, y;
  RDom r_chan(0, input.dim(2).extent());

  Func min_img;
  Func max_img;

  min_img(x, y) = minimum(input(x, y, r_chan));
  max_img(x, y) = maximum(input(x, y, r_chan));

  Buffer<uint8_t> min_out = min_img.realize({input.width(), input.height()});
  Buffer<uint8_t> max_out = max_img.realize({input.width(), input.height()});

  save_image(min_out, "min_out.png");
  save_image(max_out, "max_out.png");
}

您只需创建一个跨越通道数(由 input.dim(2).extent() 给定)的缩减域。如果你不想使用助手,你可以改写(例如 min_img):

min_img(x, y) = cast<uint8_t>(255);
min_img(x, y) = min(input(x, y, r_chan), min_img(x, y));

缩减域的工作方式(没有其他调度指令)是它们插入一个最内层的循环,为缩减域中的每个点重复规则。因此,第二个示例中的循环如下所示:

for y:
  for x:
    min_img(...) = ...
for y:
  for x:
    for r_chan in [0, 2]:
      min_img(...) = ...

如果您担心优化器没有消除将所有内容初始化为 255 的第一个循环,那么您可以改写:

min_img(x, y) = undef(type_of<uint8_t>()); // skip the pure step
min_img(x, y) = min(
  input(x, y, r_chan),
  select(r_chan == 0, cast<uint8_t>(255), min_img(x, y))
);

在这种情况下,我希望 LLVM 剥离第一次迭代(至少)。