Halide:如何避免不必要的断言

Halide: How to avoid unwanted assertions

在 Halide 的管道开发过程中, 我想避免对缓冲区布局进行不必要的检查。 我知道我可以使用 'no_asserts' 目标功能关闭大多数断言。

但是,我生成了以下简单代码:

#define LUT_SIZE 17     /* Size in each dimension of the 4D LUT */

class ApplyLut : public Halide::Generator<ApplyLut> {
public:
    // We declare the Inputs to the Halide pipeline as public
    // member variables. They'll appear in the signature of our generated
    // function in the same order as we declare them.
  Input <  Buffer<uint8_t>> Lut              { "Lut"            , 1};  // LUT to apply
  Input <  Buffer<int>> indexToLut           { "indexToLut"     , 1};  // Precalculated mapping of uint8_t to LUT index
  Input <  Buffer<uint8_t >> inputImageLine  { "inputImageLine" , 1};  // Input line
  Output<  Buffer<uint8_t >> outputImageLine { "outputImageLine", 1};  // Output line
  void generate();
};

HALIDE_REGISTER_GENERATOR(ApplyLut, outputImageLine)

void ApplyLut::generate()
{
  Var x("x");

  outputImageLine(x) = Lut(clamp(indexToLut(inputImageLine(x)), 0, LUT_SIZE));

  inputImageLine .dim(0).set_min(0);         // Input image sample index
  inputImageLine .dim(0).set_stride(1);         // Input image sample index
  outputImageLine.dim(0).set_bounds(0, inputImageLine.dim(0).extent()); // Output line matches input line
  outputImageLine.dim(0).set_stride(   inputImageLine.dim(0).stride()); // Output line matches input line
  Lut            .dim(0).set_bounds(0, LUT_SIZE);          //iccLut[...]: , limited number of values
  Lut            .dim(0).set_stride(1);                    //iccLut[...]: , limited number of values
  indexToLut     .dim(0).set_bounds(0, 256);               //chan4_offset[...]: value index: 256 values
  indexToLut     .dim(0).set_stride(1);                    //chan4_offset[...]: value index: 256 values
}

除其他外,我在生成过程中使用了目标特征 'no_assert'(如输出中所示)。 然后我得到以下输出代码:

module name=applyIccProfile, target=x86-64-windows-disable_llvm_loop_opt-mingw-no_asserts-no_bounds_query-no_runtime-sse41 {
  func applyIccProfile(Lut, indexToLut, inputImageLine, outputImageLine) {
    assert((reinterpret(outputImageLine.buffer) != (uint64)0), halide_error_buffer_argument_is_null("outputImageLine"))
    assert((reinterpret(inputImageLine.buffer) != (uint64)0), halide_error_buffer_argument_is_null("inputImageLine"))
    assert((reinterpret(indexToLut.buffer) != (uint64)0), halide_error_buffer_argument_is_null("indexToLut"))
    assert((reinterpret(Lut.buffer) != (uint64)0), halide_error_buffer_argument_is_null("Lut"))
    let Lut = _halide_buffer_get_host(Lut.buffer)
    let Lut.min.0 = _halide_buffer_get_min(Lut.buffer, 0)
    let Lut.extent.0 = _halide_buffer_get_extent(Lut.buffer, 0)
    let Lut.stride.0 = _halide_buffer_get_stride(Lut.buffer, 0)
    let indexToLut = _halide_buffer_get_host(indexToLut.buffer)
    let indexToLut.min.0 = _halide_buffer_get_min(indexToLut.buffer, 0)
    let indexToLut.extent.0 = _halide_buffer_get_extent(indexToLut.buffer, 0)
    let indexToLut.stride.0 = _halide_buffer_get_stride(indexToLut.buffer, 0)
    let inputImageLine = _halide_buffer_get_host(inputImageLine.buffer)
    let inputImageLine.min.0 = _halide_buffer_get_min(inputImageLine.buffer, 0)
    let inputImageLine.extent.0 = _halide_buffer_get_extent(inputImageLine.buffer, 0)
    let inputImageLine.stride.0 = _halide_buffer_get_stride(inputImageLine.buffer, 0)
    let outputImageLine = _halide_buffer_get_host(outputImageLine.buffer)
    let outputImageLine.min.0 = _halide_buffer_get_min(outputImageLine.buffer, 0)
    let outputImageLine.extent.0 = _halide_buffer_get_extent(outputImageLine.buffer, 0)
    let outputImageLine.stride.0 = _halide_buffer_get_stride(outputImageLine.buffer, 0)

    assert((Lut.stride.0 == 1), 0)
    assert((Lut.min.0 == 0), 0)
    assert((Lut.extent.0 == 17), 0)
    assert((indexToLut.stride.0 == 1), 0)
    assert((indexToLut.min.0 == 0), 0)
    assert((indexToLut.extent.0 == 256), 0)
    assert((inputImageLine.stride.0 == 1), 0)
    assert((inputImageLine.min.0 == 0), 0)
    assert((outputImageLine.stride.0 == 1), 0)
    assert((outputImageLine.min.0 == 0), 0)
    assert((outputImageLine.extent.0 == inputImageLine.extent.0), 0)
    produce outputImageLine {
      for (outputImageLine.s0.x, 0, inputImageLine.extent.0) {
        outputImageLine[outputImageLine.s0.x] = Lut[max(min(indexToLut[int32(inputImageLine[outputImageLine.s0.x])], 17), 0)]
      }
    }
  }
}

在生成的输出中存在许多断言 检查所提供缓冲区的尺寸。

我知道这些断言每次调用都会执行一次 'only'。
但是,考虑到我想关闭这些断言的调用次数, 因为执行开销。

所以问题是:

虽然在 no_asserts 打开时断言仍然出现在 Halide IR 中,但任何剩余的断言都会在最终降低到 LLVM IR 时被剥离。它们只存在于 Halide IR 中,因为它们让 Halide 简化器知道在代码中的那个点之后可以假定某些东西为真,但它们编译为无操作。

断言消失后,LLVM 将通过死代码消除不必要的赋值。我会检查生成的程序集而不是 Halide IR,以确保所有这些检查都已完成。