推力减少和重载运算符-(const float3&, const float3&) 不会编译

Thrust reduction and overloaded operator-(const float3&, const float3&) won't compile

我在 vectorspace.cuh:

中重载运算符以在 float3(和类似结构)上有一个向量 space
// Boilerplate vector space over data type Pt
#pragma once

#include <type_traits>


// float3
__device__ __host__ float3 operator+=(float3& a, const float3& b) {
    a.x += b.x; a.y += b.y; a.z += b.z;
    return a;
}

__device__ __host__ float3 operator*=(float3& a, const float b) {
    a.x *= b; a.y *= b; a.z *= b;
    return a;
}

// float4
__device__ __host__ float4 operator+=(float4& a, const float4& b) {
    a.x += b.x; a.y += b.y; a.z += b.z; a.w += b.w;
    return a;
}

__device__ __host__ float4 operator*=(float4& a, const float b) {
    a.x *= b; a.y *= b; a.z *= b; a.w *= b;
    return a;
}


// Generalize += and *= to +, -=, -, *, /= and /
template<typename Pt> __device__ __host__
typename std::enable_if<std::is_class<Pt>::value || std::is_enum<Pt>::value, Pt>::type
operator+(const Pt& a, const Pt& b) {
    auto sum = a;
    sum += b;
    return sum;
}

template<typename Pt> __device__ __host__
typename std::enable_if<std::is_class<Pt>::value || std::is_enum<Pt>::value, Pt>::type
operator-=(Pt& a, const Pt& b) {
    a += -1*b;
    return a;
}

template<typename Pt> __device__ __host__
typename std::enable_if<std::is_class<Pt>::value || std::is_enum<Pt>::value, Pt>::type
operator-(const Pt& a, const Pt& b) {
    auto diff = a;
    diff -= b;
    return diff;
}

template<typename Pt> __device__ __host__
typename std::enable_if<std::is_class<Pt>::value || std::is_enum<Pt>::value, Pt>::type
operator-(const Pt& a) {
    return -1*a;
}

template<typename Pt> __device__ __host__
typename std::enable_if<std::is_class<Pt>::value || std::is_enum<Pt>::value, Pt>::type
operator*(const Pt& a, const float b) {
    auto prod = a;
    prod *= b;
    return prod;
}

template<typename Pt> __device__ __host__
typename std::enable_if<std::is_class<Pt>::value || std::is_enum<Pt>::value, Pt>::type
operator*(const float b, const Pt& a) {
    auto prod = a;
    prod *= b;
    return prod;
}

template<typename Pt> __device__ __host__
typename std::enable_if<std::is_class<Pt>::value || std::is_enum<Pt>::value, Pt>::type
operator/=(Pt& a, const float b) {
    a *= 1./b;
    return a;
}

template<typename Pt> __device__ __host__
typename std::enable_if<std::is_class<Pt>::value || std::is_enum<Pt>::value, Pt>::type
operator/(const Pt& a, const float b) {
    auto quot = a;
    quot /= b;
    return quot;
}

这些重载会破坏 thrust::reduce 的编译,这里有一个例子:

#include <thrust/reduce.h>
#include <thrust/execution_policy.h>

#include "vectorspace.cuh"


int main(int argc, char const *argv[]) {
    int n = 10;
    float3* d_arr;
    cudaMalloc(&d_arr, n*sizeof(float3));

    auto sum = thrust::reduce(thrust::device, d_arr, d_arr + n, float3 {0});

    return 0;
}

在 Ubuntu 16.04 上使用 nvcc -std=c++11 -arch=sm_52 这会导致 200 多行编译器错误:

$ nvcc -std=c++11 -arch=sm_52 sandbox/mean.cu 
sandbox/mean.cu(26): error: no operator "*" matches these operands
            operand types are: int * const thrust::zip_iterator<thrust::tuple<const float3 *, thrust::pointer<float3, thrust::system::cuda::detail::par_t, thrust::use_default, thrust::use_default>, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>>
          detected during:
            instantiation of "std::enable_if<<expression>, Pt>::type operator-=(Pt &, const Pt &) [with Pt=thrust::zip_iterator<thrust::tuple<const float3 *, thrust::pointer<float3, thrust::system::cuda::detail::par_t, thrust::use_default, thrust::use_default>, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>>]" 
(35): here
            instantiation of "std::enable_if<<expression>, Pt>::type operator-(const Pt &, const Pt &) [with Pt=thrust::zip_iterator<thrust::tuple<const float3 *, thrust::pointer<float3, thrust::system::cuda::detail::par_t, thrust::use_default, thrust::use_default>, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>>]" 

...

如何在不破坏 thrust 的情况下重载运算符?

(根据 OP 的编辑进行编辑。)

问题出在运算符重载的 'reach' 上:您不仅重载了您感兴趣的 类,还重载了 all 类 适合你的 enable_if 条件 - 这很放松。即使可以编译,这也已经是一个严重的错误。

更具体地说,推力使用算术运算,例如在 "zip iterators" 上(不要管它们是什么),并且可以理解,对此类迭代器的操作编译失败。

所以你必须:

  • 准确指定重载与哪个 类 相关(例如,在 enable_if 中使用 std::is_same 的析取),或
  • 使用 trait class:

    template<class T> struct needs_qivs_arithmetic_operators : public std::false_type {};
    
    template<> struct needs_qivs_arithmetic_operators<float3> : public std::true_type {};
    template<> struct needs_qivs_arithmetic_operators<float4> : public std::true_type {};
    /* ... etc. You can also add specializations elsewhere in the translation unit. */