如何导出已经从 Halide::Generator 导出的 class?
How to derive a class that is already derived from Halide::Generator?
我想在Halide/C++中创建一个基于Halide::Generator的基本继承结构,以避免重复代码。
想法是拥有一个拥有纯虚函数的抽象基础生成器class。此外,每个派生 class 应该有一个特定的输入参数,在基础 class.
中不可用
在普通的 C++ 中,这非常简单,但由于 Halide 是一种在链接和编译之前“生成代码”的 DSL,事情可能会变得有点混乱。
我当前的 Halide 实现全部在一个文件中:
my_generators.cpp
#include "Halide.h"
#include <stdio.h>
using namespace Halide;
class Base : public Halide::Generator<Base> {
public:
Input<Buffer<float>> input{"input", 2};
Output<Buffer<float>> output{"brighter", 2};
Var x, y;
virtual Func process(Func input) = 0;
virtual void generate() {
output = process(input);
output.vectorize(x, 16).parallel(y);
}
};
class DerivedGain : public Base {
public:
Input<float> gain{"gain"};
Func process (Func input) override{
Func result("result");
result(x,y) = input(x,y) * gain;
return result;
}
};
class DerivedOffset : public Base{
public:
Input<float> offset{"offset"};
Func process (Func input) override{
Func result("result");
result(x,y) = input(x,y) + offset;
return result;
}
};
HALIDE_REGISTER_GENERATOR(DerivedGain, derived_gain)
HALIDE_REGISTER_GENERATOR(DerivedOffset, derived_offset)
为了编译它,我使用了这个 CMakeLists 文件:
cmake_minimum_required(VERSION 3.16)
project(HalideExample)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED YES)
set(CMAKE_CXX_EXTENSIONS NO)
find_package(Halide REQUIRED)
add_executable(my_generators my_generators.cpp)
target_include_directories(my_generators PUBLIC ${HALIDE_ROOT}/include)
target_link_libraries(my_generators PRIVATE Halide::Generator)
add_halide_library(derived_gain FROM my_generators)
add_halide_library(derived_offset FROM my_generators)
我一直在使用预建版本Halide-13.0.1-x86-64-linux here
但是在编译过程中,它启动了一个错误,提示 class Base
正在被实例化(我不需要发生):
In file included from <path_to_project>/my_generators.cpp:2:
<path_to_halide>/include/Halide.h: In instantiation of ‘static std::unique_ptr<_Tp> Halide::Generator<T>::create(const Halide::GeneratorContext&) [with T = Base]’:
<path_to_halide>/include/Halide.h:26640:14: required from ‘static std::unique_ptr<_Tp> Halide::Generator<T>::create(const Halide::GeneratorContext&, const string&, const string&) [with T = Base; std::string = std::__cxx11::basic_string<char>]’
<path_to_project>/my_generators.cpp:53:1: required from here
<path_to_halide>/include/Halide.h:26631:37: error: invalid new-expression of abstract class type ‘Base’
26631 | auto g = std::unique_ptr<T>(new T());
| ^~~~~~~
<path_to_project>/my_generators.cpp:7:7: note: because the following virtual functions are pure within ‘Base’:
7 | class Base : public Halide::Generator<Base> {
| ^~~~
<path_to_project>/my_generators.cpp:15:18: note: ‘virtual Halide::NamesInterface::Func Base::process(Halide::NamesInterface::Func)’
15 | virtual Func process(Func input) = 0;
如果我不使用 virtual
函数而是在 Base class 中实现它,如下所示:
class Base : public Halide::Generator<Base> {
public:
Input<Buffer<float>> input{"input", 2};
Output<Buffer<float>> output{"brighter", 2};
Var x, y;
// Func process(Func input);
Func process (Func input){
Func result("result");
result(x,y) = input(x,y);
return result;
}
virtual void generate() {
output = process(input);
output.vectorize(x, 16).parallel(y);
}
};
然后一切都可以编译,但是带有生成代码的目标文件和头文件具有错误的函数签名(值得注意的是缺少 gain/offset 参数):
derived_gain.h:
int derived_gain(struct halide_buffer_t *_input_buffer, struct halide_buffer_t *_result_buffer);
derived_offset.h:
int derived_offset(struct halide_buffer_t *_input_buffer, struct halide_buffer_t *_result_buffer);
因此,我想知道我在 class 定义中引入了哪个错误以及如何解决它。
您可以将基础 class 变成模板:
template<class T>
class Base : public Halide::Generator<T> {
然后重新导出 Input
和 Output
名称...(我不是 C++ 大师,无法理解为什么这是必要的):
// In class Base:
template <typename T2>
using Input = typename Halide::Generator<T>::template Input<T2>;
template <typename T2>
using Output = typename Halide::Generator<T>::template Output<T2>;
那么剩下的改动就是:
class DerivedGain : public Base<DerivedGain> { ... };
class DerivedOffset : public Base<DerivedOffset> { ... };
这似乎对我有用。
此外,您的 CMakeLists.txt 中可能不需要这一行(我不需要):
target_include_directories(my_generators PUBLIC ${HALIDE_ROOT}/include)
我们的包没有设置 HALIDE_ROOT
,链接到 Halide::Generator
已经正确设置了包含路径。
我想在Halide/C++中创建一个基于Halide::Generator的基本继承结构,以避免重复代码。
想法是拥有一个拥有纯虚函数的抽象基础生成器class。此外,每个派生 class 应该有一个特定的输入参数,在基础 class.
中不可用在普通的 C++ 中,这非常简单,但由于 Halide 是一种在链接和编译之前“生成代码”的 DSL,事情可能会变得有点混乱。
我当前的 Halide 实现全部在一个文件中:
my_generators.cpp
#include "Halide.h"
#include <stdio.h>
using namespace Halide;
class Base : public Halide::Generator<Base> {
public:
Input<Buffer<float>> input{"input", 2};
Output<Buffer<float>> output{"brighter", 2};
Var x, y;
virtual Func process(Func input) = 0;
virtual void generate() {
output = process(input);
output.vectorize(x, 16).parallel(y);
}
};
class DerivedGain : public Base {
public:
Input<float> gain{"gain"};
Func process (Func input) override{
Func result("result");
result(x,y) = input(x,y) * gain;
return result;
}
};
class DerivedOffset : public Base{
public:
Input<float> offset{"offset"};
Func process (Func input) override{
Func result("result");
result(x,y) = input(x,y) + offset;
return result;
}
};
HALIDE_REGISTER_GENERATOR(DerivedGain, derived_gain)
HALIDE_REGISTER_GENERATOR(DerivedOffset, derived_offset)
为了编译它,我使用了这个 CMakeLists 文件:
cmake_minimum_required(VERSION 3.16)
project(HalideExample)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED YES)
set(CMAKE_CXX_EXTENSIONS NO)
find_package(Halide REQUIRED)
add_executable(my_generators my_generators.cpp)
target_include_directories(my_generators PUBLIC ${HALIDE_ROOT}/include)
target_link_libraries(my_generators PRIVATE Halide::Generator)
add_halide_library(derived_gain FROM my_generators)
add_halide_library(derived_offset FROM my_generators)
我一直在使用预建版本Halide-13.0.1-x86-64-linux here
但是在编译过程中,它启动了一个错误,提示 class Base
正在被实例化(我不需要发生):
In file included from <path_to_project>/my_generators.cpp:2:
<path_to_halide>/include/Halide.h: In instantiation of ‘static std::unique_ptr<_Tp> Halide::Generator<T>::create(const Halide::GeneratorContext&) [with T = Base]’:
<path_to_halide>/include/Halide.h:26640:14: required from ‘static std::unique_ptr<_Tp> Halide::Generator<T>::create(const Halide::GeneratorContext&, const string&, const string&) [with T = Base; std::string = std::__cxx11::basic_string<char>]’
<path_to_project>/my_generators.cpp:53:1: required from here
<path_to_halide>/include/Halide.h:26631:37: error: invalid new-expression of abstract class type ‘Base’
26631 | auto g = std::unique_ptr<T>(new T());
| ^~~~~~~
<path_to_project>/my_generators.cpp:7:7: note: because the following virtual functions are pure within ‘Base’:
7 | class Base : public Halide::Generator<Base> {
| ^~~~
<path_to_project>/my_generators.cpp:15:18: note: ‘virtual Halide::NamesInterface::Func Base::process(Halide::NamesInterface::Func)’
15 | virtual Func process(Func input) = 0;
如果我不使用 virtual
函数而是在 Base class 中实现它,如下所示:
class Base : public Halide::Generator<Base> {
public:
Input<Buffer<float>> input{"input", 2};
Output<Buffer<float>> output{"brighter", 2};
Var x, y;
// Func process(Func input);
Func process (Func input){
Func result("result");
result(x,y) = input(x,y);
return result;
}
virtual void generate() {
output = process(input);
output.vectorize(x, 16).parallel(y);
}
};
然后一切都可以编译,但是带有生成代码的目标文件和头文件具有错误的函数签名(值得注意的是缺少 gain/offset 参数):
derived_gain.h:
int derived_gain(struct halide_buffer_t *_input_buffer, struct halide_buffer_t *_result_buffer);
derived_offset.h:
int derived_offset(struct halide_buffer_t *_input_buffer, struct halide_buffer_t *_result_buffer);
因此,我想知道我在 class 定义中引入了哪个错误以及如何解决它。
您可以将基础 class 变成模板:
template<class T>
class Base : public Halide::Generator<T> {
然后重新导出 Input
和 Output
名称...(我不是 C++ 大师,无法理解为什么这是必要的):
// In class Base:
template <typename T2>
using Input = typename Halide::Generator<T>::template Input<T2>;
template <typename T2>
using Output = typename Halide::Generator<T>::template Output<T2>;
那么剩下的改动就是:
class DerivedGain : public Base<DerivedGain> { ... };
class DerivedOffset : public Base<DerivedOffset> { ... };
这似乎对我有用。
此外,您的 CMakeLists.txt 中可能不需要这一行(我不需要):
target_include_directories(my_generators PUBLIC ${HALIDE_ROOT}/include)
我们的包没有设置 HALIDE_ROOT
,链接到 Halide::Generator
已经正确设置了包含路径。