使用元组累加器减少推力
Thrust reduce with tuple accumulator
我想在 thrust::tuple<double,double>
的 thrust::host_vector
上使用 thrust::reduce
。因为没有预定义的 thrust::plus<thrust::tuple<double,double>>
我自己编写了 thrust::reduce
的变体和四个参数。
因为我是一个好公民,所以我将 plus
的自定义版本放在我自己的命名空间中,在那里我没有定义主要模板并将其专门用于 thrust::tuple<T...>
.
#include <iostream>
#include <tuple>
#include <thrust/host_vector.h>
#include <thrust/reduce.h>
#include <thrust/tuple.h>
namespace thrust_ext {
namespace detail {
//
template <size_t ...I>
struct index_sequence {};
template <size_t N, size_t ...I>
struct make_index_sequence : public make_index_sequence<N - 1, N - 1, I...> {};
template <size_t ...I>
struct make_index_sequence<0, I...> : public index_sequence<I...> {};
template < typename... T, size_t... I >
__host__ __device__ thrust::tuple<T...> plus(thrust::tuple<T...> const &lhs,
thrust::tuple<T...> const &rhs,
index_sequence<I...>) {
return {thrust::get<I>(lhs) + thrust::get<I>(rhs) ...};
}
} // namespace detail
template < typename T >
struct plus;
template < typename... T >
struct plus < thrust::tuple<T...> > {
__host__ __device__ thrust::tuple<T...> operator()(thrust::tuple<T...> const &lhs,
thrust::tuple<T...> const &rhs) const {
return detail::plus(lhs,rhs,detail::make_index_sequence<sizeof...(T)>{});
}
};
} //namespace thrust_ext
int main() {
thrust::host_vector<thrust::tuple<double,double>> v(10, thrust::make_tuple(1.0,2.0));
auto r = thrust::reduce(v.begin(), v.end(),
thrust::make_tuple(0.0,0.0),
thrust_ext::plus<thrust::tuple<double,double>>{});
std::cout << thrust::get<0>(r) << ' ' << thrust::get<1>(r) << '\n';
}
但是,这不会编译。错误消息非常长,请参阅 this Gist。错误消息表明问题出在 thrust::reduce
的某些实现细节中。此外,如果我将 thrust::tuple
替换为 std::tuple
,它会按预期编译和运行。
我在 Clang 6 中使用 Thrust 1.8.1。
如您在错误消息中所见,thrust::tuple<double,double>
实际上是
thrust::tuple<double, double, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>
这是一个基于默认模板参数的 C++03 风格 "variadic template",这意味着 sizeof...(T)
将计算所有 null_type
并产生错误的大小(总是10).
您需要使用 thrust::tuple_size
来检索实际大小。
我想在 thrust::tuple<double,double>
的 thrust::host_vector
上使用 thrust::reduce
。因为没有预定义的 thrust::plus<thrust::tuple<double,double>>
我自己编写了 thrust::reduce
的变体和四个参数。
因为我是一个好公民,所以我将 plus
的自定义版本放在我自己的命名空间中,在那里我没有定义主要模板并将其专门用于 thrust::tuple<T...>
.
#include <iostream>
#include <tuple>
#include <thrust/host_vector.h>
#include <thrust/reduce.h>
#include <thrust/tuple.h>
namespace thrust_ext {
namespace detail {
//
template <size_t ...I>
struct index_sequence {};
template <size_t N, size_t ...I>
struct make_index_sequence : public make_index_sequence<N - 1, N - 1, I...> {};
template <size_t ...I>
struct make_index_sequence<0, I...> : public index_sequence<I...> {};
template < typename... T, size_t... I >
__host__ __device__ thrust::tuple<T...> plus(thrust::tuple<T...> const &lhs,
thrust::tuple<T...> const &rhs,
index_sequence<I...>) {
return {thrust::get<I>(lhs) + thrust::get<I>(rhs) ...};
}
} // namespace detail
template < typename T >
struct plus;
template < typename... T >
struct plus < thrust::tuple<T...> > {
__host__ __device__ thrust::tuple<T...> operator()(thrust::tuple<T...> const &lhs,
thrust::tuple<T...> const &rhs) const {
return detail::plus(lhs,rhs,detail::make_index_sequence<sizeof...(T)>{});
}
};
} //namespace thrust_ext
int main() {
thrust::host_vector<thrust::tuple<double,double>> v(10, thrust::make_tuple(1.0,2.0));
auto r = thrust::reduce(v.begin(), v.end(),
thrust::make_tuple(0.0,0.0),
thrust_ext::plus<thrust::tuple<double,double>>{});
std::cout << thrust::get<0>(r) << ' ' << thrust::get<1>(r) << '\n';
}
但是,这不会编译。错误消息非常长,请参阅 this Gist。错误消息表明问题出在 thrust::reduce
的某些实现细节中。此外,如果我将 thrust::tuple
替换为 std::tuple
,它会按预期编译和运行。
我在 Clang 6 中使用 Thrust 1.8.1。
如您在错误消息中所见,thrust::tuple<double,double>
实际上是
thrust::tuple<double, double, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>
这是一个基于默认模板参数的 C++03 风格 "variadic template",这意味着 sizeof...(T)
将计算所有 null_type
并产生错误的大小(总是10).
您需要使用 thrust::tuple_size
来检索实际大小。