检测数组中大于阈值的值,并使用推力将结果存储在二进制 (1/0) 数组中

Detect values greater than threshold in an array and store result in a binary(1/0) array using thrust

给定一个输入数组和一个阈值,我需要创建一个输出二进制数组,其中 1 表示大于阈值的值,0 表示小于阈值的值。我需要使用推力

我下面的尝试解决了问题,但看起来很笨拙。 如何一步完成。我的目标是在最短的计算时间内完成。

#include <thrust/replace.h>
#include <thrust/execution_policy.h>
#include <thrust/fill.h>
#include <thrust/device_vector.h>

int main(int argc, char * argv[])
{
int threshold=1;
thrust::device_vector<int> S(6);
S[0] = 1;
S[1] = 2;
S[2] = 3;
S[3] = 4;
S[4] = 5;
S[5] = 6;

// fill vector with zeros
thrust::device_vector<int> A(6);
thrust::fill(thrust::device, A.begin(), A.end(), 0);

// detect indices with values greater than zero
thrust::device_vector<int> indices(6);
thrust::device_vector<int>::iterator end = thrust::copy_if(thrust::make_counting_iterator(0),thrust::make_counting_iterator(6),S.begin(),indices.begin(),                                                              thrust::placeholders::_1 > threshold);
int size = end-indices.begin();
indices.resize(size);

// use permutation iterator along with indices above to change to ones

thrust::replace(thrust::device,thrust::make_permutation_iterator(A.begin(), indices.begin()), thrust::make_permutation_iterator(A.begin(), indices.end()), 0, 1);

for (int i=0;i<6;i++)
{
std::cout << "A["<<i<<"]=" << A[i] << std::endl;
}

return 0;
}

索引检测部分取自

只需使用自定义比较函子调用 thrust::transform 即可实现所需的功能。这是上述方法的一个例子。

#include <thrust/execution_policy.h>
#include <thrust/device_vector.h>
#include <thrust/transform.h>    

template<class T>
struct thresher
{
    T _thresh;
    thresher(T thresh) : _thresh(thresh) { }

    __host__ __device__ int operator()(T &x) const
    {
        return int(x > _thresh);
    }
};

int main(int argc, char * argv[])
{
    int threshold = 1;
    thrust::device_vector<int> S(6);
    S[0] = 1;
    S[1] = 2;
    S[2] = 3;
    S[3] = 4;
    S[4] = 5;
    S[5] = 6;

    thrust::device_vector<int> A(6);
    thrust::transform(S.begin(), S.end(), A.begin(), thresher<int>(threshold));

    for (int i=0;i<6;i++)
    {
        std::cout << "A["<<i<<"]=" << A[i] << std::endl;
    }

    return 0;
}