从整数值到类型的动态转换(C++11 模板元编程?)
Dynamic conversion from integer value to type (C++11 template metaprogramming?)
我正在尝试借助模板减少代码重复。我已经把大部分代码移到了辅助函数 iterate_function_from_CSC_helper 中,它现在是一个函数模板。然而,下面这个函数仍然有大量重复代码,仅仅是为了调用该模板的不同特化:
// Returns the CSC column iterator for column `col_idx`, choosing the
// iterate_function_from_CSC_helper instantiation that matches the runtime
// value type code (`data_type`) and column-pointer type code (`col_ptr_type`).
// Each supported (data_type, col_ptr_type) pair needs its own hand-written
// branch; any other combination reaches Log::Fatal and yields nullptr.
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
  switch (data_type) {
    case C_API_DTYPE_FLOAT32:
      switch (col_ptr_type) {
        case C_API_DTYPE_INT32:
          return iterate_function_from_CSC_helper<float, int32_t>(col_ptr, indices, data, col_idx);
        case C_API_DTYPE_INT64:
          return iterate_function_from_CSC_helper<float, int64_t>(col_ptr, indices, data, col_idx);
        default:
          break;
      }
      break;
    case C_API_DTYPE_FLOAT64:
      switch (col_ptr_type) {
        case C_API_DTYPE_INT32:
          return iterate_function_from_CSC_helper<double, int32_t>(col_ptr, indices, data, col_idx);
        case C_API_DTYPE_INT64:
          return iterate_function_from_CSC_helper<double, int64_t>(col_ptr, indices, data, col_idx);
        default:
          break;
      }
      break;
    default:
      break;
  }
  Log::Fatal("Unknown data type in CSC matrix");
  return nullptr;
}
我想把运行时收到的整数 data_type 和 col_ptr_type 分别自动映射到类型 float/double 和 int32_t/int64_t,然后用这些类型来调用模板,像这样:
// Illustrative pseudo-code of the desired API -- it does NOT compile as written:
//  * data_type and col_ptr_type are runtime ints, but template arguments such as
//    TTag<data_type> must be compile-time constant expressions;
//  * the `<` before the first TTag is stray, and `data_col` presumably means
//    data_type;
//  * TTag has no `invalid_type` member -- unknown codes map TTag<code>::type to
//    TTagInvalidType instead.
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
CHECK(col_idx < ncol_ptr && col_idx >= 0);
if (<TTag<data_col>::invalid_type || TTag<col_ptr_type>::invalid_type) {
Log::Fatal("Unknown data type in CSC matrix");
return nullptr;
}
return iterate_function_from_CSC_helper<TTag<data_type>::type, TTag<col_ptr_type>::type>(col_ptr, indices, data, col_idx);  // wished-for runtime -> type mapping
}
这可能吗?我猜想借助一些元编程可以消除这些重复。
我尝试了下面的代码,但无法让 dummy_IterateFunctionFromCSC 接受非常量输入(而运行时的输入恰恰是非常量的):
#include <cstdint>
#include <stdio.h>
#include <iostream>
#include <type_traits>
#define C_API_DTYPE_FLOAT32 (0) /*!< \brief float32 (single precision float). */
#define C_API_DTYPE_FLOAT64 (1) /*!< \brief float64 (double precision float). */
#define C_API_DTYPE_INT32 (2) /*!< \brief int32. */
#define C_API_DTYPE_INT64 (3) /*!< \brief int64. */

//! Placeholder type that TTag yields for codes it has no C++ type for.
struct TTagInvalidType {};

//! Compile-time map from a C_API_DTYPE_* code to its C++ type, exposed as
//! TTag<code>::type. Codes without an explicit specialisation below map to
//! TTagInvalidType.
template <int C_API_DTYPE>
struct TTag { using type = TTagInvalidType; };

template <> struct TTag<C_API_DTYPE_FLOAT32> { using type = float; };
template <> struct TTag<C_API_DTYPE_FLOAT64> { using type = double; };
template <> struct TTag<C_API_DTYPE_INT32>   { using type = int32_t; };
template <> struct TTag<C_API_DTYPE_INT64>   { using type = int64_t; };
//! Demo: converts the literal 3.6 to T (deliberately lossy for integral T,
//! e.g. 3 for int) and prints the result followed by a newline.
template <typename T>
void example_f() {
  const T converted = 3.6;
  std::cout << converted
            << "\n";
}
//! Specialisation for TTagInvalidType (codes TTag could not map): prints a
//! marker instead of attempting a conversion.
template <>
void example_f<TTagInvalidType>() { std::cout << "Abort!\n"; }
template<int x>
void dummy_IterateFunctionFromCSC() {
f<typename TTag<x>::type>();
}
int main() {
  // Template arguments must be constant expressions, so an integer read at
  // runtime could not be passed to dummy_IterateFunctionFromCSC here.
  constexpr int m = 2;
  dummy_IterateFunctionFromCSC<m>();
}
这段代码可以编译,但只适用于常量 m,而不适用于从用户那里收到的整数。
这是不可能实现的吗(因为类型分派必须在编译时完成)?还是说有办法做到?如果可以,该怎么做? :D
谢谢:)
将运行时值转换为编译时值,确实需要像您那样使用 if/switch。
不过,您可以通过进一步拆分(对每个参数分别做分派)来避免一部分重复。
在 C++17 中,可以借助 std::variant 减少冗长的代码。先定义一些工具:
/// Empty tag that carries a type: it lets a type travel as a value (for
/// example as a std::variant alternative) and be recovered through `::type`.
/// Equivalent to C++20 std::type_identity.
template <typename T>
struct type_identity {
    using type = T;
};
// Turns a runtime C_API_DTYPE_* integer code into a type tag for int32_t or
// int64_t. Unknown codes are reported through Log::Fatal.
// NOTE(review): `type` should ideally be an enum rather than a plain int.
std::variant<type_identity<int32_t>, type_identity<int64_t>> to_compile_int_type(int type)
{
    if (type == C_API_DTYPE_INT32) {
        return type_identity<int32_t>{};
    }
    if (type == C_API_DTYPE_INT64) {
        return type_identity<int64_t>{};
    }
    Log::Fatal("Unknown int data type");
    // Keeps every path from falling off the end; presumably unreachable if
    // Log::Fatal never returns. NOTE(review): throwing a string literal rather
    // than a std::exception subclass is kept to preserve the thrown type.
    throw "unknown type";
}
// Turns a runtime C_API_DTYPE_* floating-point code into a type tag for float
// or double. Unknown codes are reported through Log::Fatal.
// NOTE(review): `type` should ideally be an enum rather than a plain int.
std::variant<type_identity<float>, type_identity<double>> to_compile_float_type(int type)
{
    if (type == C_API_DTYPE_FLOAT32) {
        return type_identity<float>{};
    }
    if (type == C_API_DTYPE_FLOAT64) {
        return type_identity<double>{};
    }
    Log::Fatal("Unknown float data type");
    // Keeps every path from falling off the end; presumably unreachable if
    // Log::Fatal never returns. NOTE(review): throwing a string literal rather
    // than a std::exception subclass is kept to preserve the thrown type.
    throw "unknown type";
}
然后,这样实现该函数:
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr,
int col_ptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t ncol_ptr,
int64_t ,
int col_idx)
{
CHECK(col_idx < ncol_ptr && col_idx >= 0);
std::visit(
[&](auto intvar, auto floatvar){
using inttype = typename decltype(intvar)::type;
using floattype = typename decltype(floatvar)::type;
return iterate_function_from_CSC_helper<floatype, inttype>(col_ptr, indices, data, col_idx);
},
to_compile_int_type(col_ptr_type),
to_compile_float_type(data_type)
);
}
我正在尝试借助模板减少代码重复。我已经把大部分代码移到了辅助函数 iterate_function_from_CSC_helper 中,它现在是一个函数模板。然而,下面这个函数仍然有大量重复代码,仅仅是为了调用该模板的不同特化:
// Returns the CSC column iterator for column `col_idx`, choosing the
// iterate_function_from_CSC_helper instantiation that matches the runtime
// value type code (`data_type`) and column-pointer type code (`col_ptr_type`).
// Each supported (data_type, col_ptr_type) pair needs its own hand-written
// branch; any other combination reaches Log::Fatal and yields nullptr.
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
  switch (data_type) {
    case C_API_DTYPE_FLOAT32:
      switch (col_ptr_type) {
        case C_API_DTYPE_INT32:
          return iterate_function_from_CSC_helper<float, int32_t>(col_ptr, indices, data, col_idx);
        case C_API_DTYPE_INT64:
          return iterate_function_from_CSC_helper<float, int64_t>(col_ptr, indices, data, col_idx);
        default:
          break;
      }
      break;
    case C_API_DTYPE_FLOAT64:
      switch (col_ptr_type) {
        case C_API_DTYPE_INT32:
          return iterate_function_from_CSC_helper<double, int32_t>(col_ptr, indices, data, col_idx);
        case C_API_DTYPE_INT64:
          return iterate_function_from_CSC_helper<double, int64_t>(col_ptr, indices, data, col_idx);
        default:
          break;
      }
      break;
    default:
      break;
  }
  Log::Fatal("Unknown data type in CSC matrix");
  return nullptr;
}
我想把运行时收到的整数 data_type 和 col_ptr_type 分别自动映射到类型 float/double 和 int32_t/int64_t,然后用这些类型来调用模板,像这样:
// Illustrative pseudo-code of the desired API -- it does NOT compile as written:
//  * data_type and col_ptr_type are runtime ints, but template arguments such as
//    TTag<data_type> must be compile-time constant expressions;
//  * the `<` before the first TTag is stray, and `data_col` presumably means
//    data_type;
//  * TTag has no `invalid_type` member -- unknown codes map TTag<code>::type to
//    TTagInvalidType instead.
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
CHECK(col_idx < ncol_ptr && col_idx >= 0);
if (<TTag<data_col>::invalid_type || TTag<col_ptr_type>::invalid_type) {
Log::Fatal("Unknown data type in CSC matrix");
return nullptr;
}
return iterate_function_from_CSC_helper<TTag<data_type>::type, TTag<col_ptr_type>::type>(col_ptr, indices, data, col_idx);  // wished-for runtime -> type mapping
}
这可能吗?我猜想借助一些元编程可以消除这些重复。
我尝试了下面的代码,但无法让 dummy_IterateFunctionFromCSC 接受非常量输入(而运行时的输入恰恰是非常量的):
#include <cstdint>
#include <stdio.h>
#include <iostream>
#include <type_traits>
#define C_API_DTYPE_FLOAT32 (0) /*!< \brief float32 (single precision float). */
#define C_API_DTYPE_FLOAT64 (1) /*!< \brief float64 (double precision float). */
#define C_API_DTYPE_INT32 (2) /*!< \brief int32. */
#define C_API_DTYPE_INT64 (3) /*!< \brief int64. */

//! Placeholder type that TTag yields for codes it has no C++ type for.
struct TTagInvalidType {};

//! Compile-time map from a C_API_DTYPE_* code to its C++ type, exposed as
//! TTag<code>::type. Codes without an explicit specialisation below map to
//! TTagInvalidType.
template <int C_API_DTYPE>
struct TTag { using type = TTagInvalidType; };

template <> struct TTag<C_API_DTYPE_FLOAT32> { using type = float; };
template <> struct TTag<C_API_DTYPE_FLOAT64> { using type = double; };
template <> struct TTag<C_API_DTYPE_INT32>   { using type = int32_t; };
template <> struct TTag<C_API_DTYPE_INT64>   { using type = int64_t; };
//! Demo: converts the literal 3.6 to T (deliberately lossy for integral T,
//! e.g. 3 for int) and prints the result followed by a newline.
template <typename T>
void example_f() {
  const T converted = 3.6;
  std::cout << converted
            << "\n";
}
//! Specialisation for TTagInvalidType (codes TTag could not map): prints a
//! marker instead of attempting a conversion.
template <>
void example_f<TTagInvalidType>() { std::cout << "Abort!\n"; }
template<int x>
void dummy_IterateFunctionFromCSC() {
f<typename TTag<x>::type>();
}
int main() {
  // Template arguments must be constant expressions, so an integer read at
  // runtime could not be passed to dummy_IterateFunctionFromCSC here.
  constexpr int m = 2;
  dummy_IterateFunctionFromCSC<m>();
}
这段代码可以编译,但只适用于常量 m,而不适用于从用户那里收到的整数。
这是不可能实现的吗(因为类型分派必须在编译时完成)?还是说有办法做到?如果可以,该怎么做? :D
谢谢:)
将运行时值转换为编译时值,确实需要像您那样使用 if/switch。
不过,您可以通过进一步拆分(对每个参数分别做分派)来避免一部分重复。
在 C++17 中,可以借助 std::variant 减少冗长的代码。先定义一些工具:
/// Empty tag that carries a type: it lets a type travel as a value (for
/// example as a std::variant alternative) and be recovered through `::type`.
/// Equivalent to C++20 std::type_identity.
template <typename T>
struct type_identity {
    using type = T;
};
// Turns a runtime C_API_DTYPE_* integer code into a type tag for int32_t or
// int64_t. Unknown codes are reported through Log::Fatal.
// NOTE(review): `type` should ideally be an enum rather than a plain int.
std::variant<type_identity<int32_t>, type_identity<int64_t>> to_compile_int_type(int type)
{
    if (type == C_API_DTYPE_INT32) {
        return type_identity<int32_t>{};
    }
    if (type == C_API_DTYPE_INT64) {
        return type_identity<int64_t>{};
    }
    Log::Fatal("Unknown int data type");
    // Keeps every path from falling off the end; presumably unreachable if
    // Log::Fatal never returns. NOTE(review): throwing a string literal rather
    // than a std::exception subclass is kept to preserve the thrown type.
    throw "unknown type";
}
// Turns a runtime C_API_DTYPE_* floating-point code into a type tag for float
// or double. Unknown codes are reported through Log::Fatal.
// NOTE(review): `type` should ideally be an enum rather than a plain int.
std::variant<type_identity<float>, type_identity<double>> to_compile_float_type(int type)
{
    if (type == C_API_DTYPE_FLOAT32) {
        return type_identity<float>{};
    }
    if (type == C_API_DTYPE_FLOAT64) {
        return type_identity<double>{};
    }
    Log::Fatal("Unknown float data type");
    // Keeps every path from falling off the end; presumably unreachable if
    // Log::Fatal never returns. NOTE(review): throwing a string literal rather
    // than a std::exception subclass is kept to preserve the thrown type.
    throw "unknown type";
}
然后,这样实现该函数:
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr,
int col_ptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t ncol_ptr,
int64_t ,
int col_idx)
{
CHECK(col_idx < ncol_ptr && col_idx >= 0);
std::visit(
[&](auto intvar, auto floatvar){
using inttype = typename decltype(intvar)::type;
using floattype = typename decltype(floatvar)::type;
return iterate_function_from_CSC_helper<floatype, inttype>(col_ptr, indices, data, col_idx);
},
to_compile_int_type(col_ptr_type),
to_compile_float_type(data_type)
);
}