Return推力二元函数

Return thrust binary function

我正在尝试定义一个函数,它将 return 基于字符串内容的所需类型运算符。我试过这个,但它不起作用:

impl.cpp

template <typename T> thrust::binary_function<T,T,bool>
 get_filter_operator(const std::string &op)
    if (op == "!=")
        return thrust::not_equal_to<T>();
    else if (op == ">")
        return thrust::greater<T>();
    else if (op == "<")
        return thrust::less<T>();
    else if (op == ">=")
        return thrust::greater_equal<T>();
    else if (op == "<=")
        return thrust::less_equal<T>();
    else
    {
        return thrust::equal_to<T>();
    }

template thrust::binary_function<float,float,bool> get_filter_operator<float>(const std::string &);

impl.h

template <typename T> thrust::binary_function<T, T, bool> get_filter_operator(const std::string &op);

如何 return 指向任意函数的指针,如 thrust::not_equal_to<int>()thrust::equal_to<int>()?我找不到 return.

的正确类型

编辑

根据要求,编译器错误:

在 ‘thrust::binary_function get_filter_operator(const string&) [with T = float; std::string = std::basic_string<字符>]':

错误:无法将“thrust::equal_to()”从“thrust::equal_to”转换为“thrust::binary_function ' return thrust::equal_to()

更新

很抱歉之前没有提到这个:问题是我不能使用 std::function 因为它只能在主机代码上工作。我想使用 thrust 二元函数,这样我就可以在 GPU 和 CPU.

中使用它们

以下适合我。

#include <iostream>
#include <functional>

template<class T>
std::function<bool(T, T)> GetOperator(const std::string& op)
{
    if (op == "!=")
        return std::not_equal_to<T>();
    else if (op == ">")
        return std::greater<T>();
    else if (op == "<")
        return std::less<T>();
    else if (op == ">=")
        return std::greater_equal<T>();
    else if (op == "<=")
        return std::less_equal<T>();
    else
    {
        return std::equal_to<T>();
    }
}

int main()
{
    auto op = GetOperator<int>(">");
    std::cout << op(1, 2) << '\n';
    return 0;
}

How can I return a pointer to an arbitrary function like thrust::not_equal_to(), or thrust::equal_to()? I cant find the correct type to return

您尝试 return 的每件事都是两个参数的函数, 每个类型 T,即 return 和 bool。正确的 return 类型是

std::function<bool(T, T)>

如:

#include <thrust/functional.h>
#include <functional>
#include <string>

template<typename T>
std::function<bool(T, T)>
get_filter_operator(const std::string &op)
{
    if (op == "!=")
        return thrust::not_equal_to<T>();
    else if (op == ">")
        return thrust::greater<T>();
    else if (op == "<")
        return thrust::less<T>();
    else if (op == ">=")
        return thrust::greater_equal<T>();
    else if (op == "<=")
        return thrust::less_equal<T>();
    else
    {
        return thrust::equal_to<T>();
    }
}

#include <iostream>

using namespace std;

int main()
{
    auto relop = get_filter_operator<int>("!=");
    cout << boolalpha << relop(1,0) << endl;
    cout << boolalpha << relop(1,1) << endl;

    return 0;
}

现在,您可能希望向@MohamadElghawi 重申您的评论:

Yeah, I knew that worked, but the problem is that I'm trying to return a thrust::binary_function, not from std

这可能是您想要做的,但这样做是错误的 试图做和不可能做的事情。看看定义 template<typename A1, typename A2, typename R> struct thrust::binary_function<thrust/functional> 和相关文档中。注:

binary_function is an empty base class: it contains no member functions or member variables, but only type information

特别是thrust::binary_function<A1,A2,R>没有operator()。 它是不可调用的。它不能存储任何其他可调用对象(或 任何东西)。另见 equal_tonot_equal_to 的定义, 等在同一个文件中。 binary_function 不是其中任何一个的偶数基。 没有从它们中的任何一个转换为 binary_function.

也请注意:

binary_function is currently redundant with the C++ STL type std::binary_function. We reserve it here for potential additional functionality at a later date.

std::binary_function 自 C++11 起已弃用,并将在 C++17 中删除)。

thrust::binary_function<T,T,bool> 不是您要查找的内容。 std::function<bool(T, T)> 是。

std::function<bool(int, int)> f = thrust::greater<int>(); 

使 f 封装一个可调用对象,该对象是 thrust::greater<int>

以后

The problem with this is that it can only be used in host code doesnt it? The beauty of thrust binary functions is that they can be used both in the GPU and the CPU.

我想你的印象可能是,例如

std::function<bool(int, int)> f = thrust::greater<int>();  /*A*/

采用 thrust::greater<int> 并以某种方式 将其降级 std::function<bool(int, int)> 具有类似但更受限制的 ("std") 执行能力。

并非如此。 std::function<bool(int, int)> foo 只是一个容器 任何 bar 可以用两个隐式参数调用的东西 可转换为 int 和 return 可以隐式转换为 bool, 这样如果:

std::function<bool(int, int)> foo = bar; 

然后当您调用 foo(i,j) 时,您会 returned 结果,如 bool正在执行 bar(i,j)。不是执行任何不同方式的结果 来自 bar(i,j).

因此在上面的 /*A*/ 中,f 包含并调用的可调用对象是 推力二元函数;它是一个thrust::greater<int>()。该方法 由 foperator() is thrust::greater<int>::operator().

调用

这是一个程序:

#include <thrust/functional.h>
#include <functional>
#include <iostream>

using namespace std;

int main()
{
    auto thrust_greater_than_int = thrust::greater<int>();
    std::function<bool(int, int)> f = thrust_greater_than_int;
    cout << "f " 
        << (f.target<thrust::greater<int>>() ? "calls" : "does not call") 
        << " a thrust::greater<int>" << endl;
    cout << "f " 
        << (f.target<thrust::equal_to<int>>() ? "calls" : "does not call") 
        << " a thrust::equal_to<int>" << endl;
    cout << "f " 
        << (f.target<std::greater<int>>() ? "calls" : "does not call") 
        << " an std::greater<int>" << endl;
    cout << "f " 
        << (f.target<std::function<bool(int,int)>>() ? "calls" : "does not call") 
        << " an std::function<bool(int,int)>" << endl;
    return 0;
}

thrust::greater<int> 存储在 std::function<bool(int, int)> f 中 然后通知您:

f calls a thrust::greater<int>
f does not call a thrust::equal_to<int>
f does not call an std::greater<int>
f does not call an std::function<bool(int,int)>

即使没有确切的答案,我也会把我最终使用的东西放在这里,以防有人需要类似的东西。

在 .cuh 文件中

#include <cuda.h>
#include <cuda_runtime_api.h>

namespace BinaryFunction
{
    enum class ComparisonOperator
    {
      equal_to,
      not_equal_to,
      greater,
      less,
      greater_equal,
      less_equal
    };

    enum class BitwiseOperator
    {
      bit_and,
      bit_or
    };

    template<typename T>
    struct CompareFunction
    {
      __host__ __device__ T operator()(const T &lhs, const T &rhs, ComparisonOperator &op) const
      {
            switch (op)
            {
                case ComparisonOperator::equal_to:
                    return lhs==rhs;
                case ComparisonOperator::not_equal_to:
                    return lhs!=rhs;
                case ComparisonOperator::greater:
                    return lhs>rhs;
                case ComparisonOperator::less:
                    return lhs<rhs;
                case ComparisonOperator::greater_equal:
                    return lhs>=rhs;
                case ComparisonOperator::less_equal:
                    return lhs<=rhs;

            }
        }
    };

    template<typename T>
    struct BitwiseFunction
    {
      __host__ __device__ T operator()(const T &lhs, const T &rhs, BitwiseOperator &op) const
      {
        if (op==BitwiseOperator::bit_and)
          return lhs & rhs;
        else if (op==BitwiseOperator::bit_or)
          return lhs | rhs;
      }
    };
}

然后像这样使用它: 在 cpp 文件中:

BinaryFunction::ComparisonOperator comp_op = BinaryFunction::ComparisonOperator::equal_to;

BinaryFunction::CompareFunction<int> comp_func;

然后,在内核或普通函数中:

int value_a;
int value_b;
comp_func(value_a, value_b, comp_op)