nVidia 推力:device_ptr Const-Correctness
nVidia Thrust: device_ptr Const-Correctness
在我广泛使用 nVidia CUDA 的项目中,我有时会使用 Thrust 来做一些非常非常好的事情。 Reduce 是一种在该库中实现得特别好的算法,reduce 的一种用途是通过将每个非负元素除以归一化非负元素向量元素乘以所有元素的总和。
template <typename T>
void normalise(T const* const d_input, const unsigned int size, T* d_output)
{
const thrust::device_ptr<T> X = thrust::device_pointer_cast(const_cast<T*>(d_input));
T sum = thrust::reduce(X, X + size);
thrust::constant_iterator<T> denominator(sum);
thrust::device_ptr<T> Y = thrust::device_pointer_cast(d_output);
thrust::transform(X, X + size, denominator, Y, thrust::divides<T>());
}
(T
通常是 float
或 double
)
一般来说,我不想在我的整个代码库中依赖 Thrust,所以我尝试确保像上面示例这样的函数只接受原始 CUDA 设备指针。这意味着一旦它们被 NVCC 编译,我就可以 link 将它们静态地转换为其他代码而无需 NVCC。
然而,这段代码让我担心。我希望该函数是常量正确的,但我似乎找不到 thrust::device_pointer_cast(...)
的 const
版本 - 这样的事情是否存在?在这个版本的代码中,我使用了 const_cast
,所以我在函数签名中使用了 const
,这让我很难过。
附带说明一下,将 reduce 的结果复制到主机只是为了将其发送回设备以进行下一步感觉很奇怪。有更好的方法吗?
如果你想要 const-correctness,你需要在任何地方都是 const-correct。 input
是指向 const T
的指针,因此应该是 X
:
const thrust::device_ptr<const T> X = thrust::device_pointer_cast(d_input);
在我广泛使用 nVidia CUDA 的项目中,我有时会使用 Thrust 来做一些非常非常好的事情。 Reduce 是一种在该库中实现得特别好的算法,reduce 的一种用途是通过将每个非负元素除以归一化非负元素向量元素乘以所有元素的总和。
template <typename T>
void normalise(T const* const d_input, const unsigned int size, T* d_output)
{
const thrust::device_ptr<T> X = thrust::device_pointer_cast(const_cast<T*>(d_input));
T sum = thrust::reduce(X, X + size);
thrust::constant_iterator<T> denominator(sum);
thrust::device_ptr<T> Y = thrust::device_pointer_cast(d_output);
thrust::transform(X, X + size, denominator, Y, thrust::divides<T>());
}
(T
通常是 float
或 double
)
一般来说,我不想在我的整个代码库中依赖 Thrust,所以我尝试确保像上面示例这样的函数只接受原始 CUDA 设备指针。这意味着一旦它们被 NVCC 编译,我就可以 link 将它们静态地转换为其他代码而无需 NVCC。
然而,这段代码让我担心。我希望该函数是常量正确的,但我似乎找不到 thrust::device_pointer_cast(...)
的 const
版本 - 这样的事情是否存在?在这个版本的代码中,我使用了 const_cast
,所以我在函数签名中使用了 const
,这让我很难过。
附带说明一下,将 reduce 的结果复制到主机只是为了将其发送回设备以进行下一步感觉很奇怪。有更好的方法吗?
如果你想要 const-correctness,你需要在任何地方都是 const-correct。 input
是指向 const T
的指针,因此应该是 X
:
const thrust::device_ptr<const T> X = thrust::device_pointer_cast(d_input);