尝试使用 RcppArmadillo 编写 setdiff() 函数会出现编译错误

Trying to write a setdiff() function using RcppArmadillo gives compilation error

我正在尝试使用 RcppArmadillo 在 C++ 中编写 R 的 setdiff() 函数的一种模拟。我的粗略做法:

  // [[Rcpp::export]]
  arma::uvec my_setdiff(arma::uvec x, arma::uvec y){
  // Coefficientes of unsigned integer vector y form a subset of the coefficients of unsigned integer vector x.
  // Returns set difference between the coefficients of x and those of y
  int n2 = y.n_elem;
  uword q1;
  for (int j=0 ; j<n2 ; j++){
    q1 = find(x==y[j]);
    x.shed_row(q1);
  }
  return x;
  }

编译时失败。错误内容为:

fnsauxarma.cpp:622:29: error: no matching function for call to ‘arma::Col<double>::shed_row(const arma::mtOp<unsigned int, arma::mtOp<unsigned int, arma::Col<double>, arma::op_rel_eq>,     arma::op_find>)’

我真的不知道发生了什么,任何帮助或意见将不胜感激。

问题是 arma::find return 是 uvec,并且不知道如何隐式转换为 arma::uword,正如@mtall 所指出的.您可以使用模板化的 arma::conv_to<T>::from() 函数来帮助编译器。此外,我还包含了 my_setdiff 的另一个版本 return 是 Rcpp::NumericVector 因为虽然第一个版本 return 是正确的值,但它在技术上是 matrix (即它有尺寸),我假设您希望它尽可能与 R 的 setdiff 兼容。这是通过使用 R_NilValueRcpp::attr 成员函数将 return 向量的 dim 属性设置为 NULL 来实现的。


#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

// [[Rcpp::export]]
arma::uvec my_setdiff(arma::uvec& x, const arma::uvec& y){

  for (size_t j = 0; j < y.n_elem; j++) {
    arma::uword q1 = arma::conv_to<arma::uword>::from(arma::find(x == y[j]));
    x.shed_row(q1);
  }
  return x;
}

// [[Rcpp::export]]
Rcpp::NumericVector my_setdiff2(arma::uvec& x, const arma::uvec& y){

  for (size_t j = 0; j < y.n_elem; j++) {
    arma::uword q1 = arma::conv_to<arma::uword>::from(arma::find(x == y[j]));
    x.shed_row(q1);
  }

  Rcpp::NumericVector x2 = Rcpp::wrap(x);
  x2.attr("dim") = R_NilValue;
  return x2;
}

/*** R
x <- 1:8
y <- 2:6

R> all.equal(setdiff(x,y), my_setdiff(x,y))
#[1] "Attributes: < target is NULL, current is list >" "target is numeric, current is matrix"           

R> all.equal(setdiff(x,y), my_setdiff2(x,y))
#[1] TRUE

R> setdiff(x,y)
#[1] 1 7 8

R> my_setdiff(x,y)
# [,1]
# [1,]    1
# [2,]    7
# [3,]    8

R> my_setdiff2(x,y)
#[1] 1 7 8

*/

编辑: 为了完整起见,这里是 setdiff 的一个比上面介绍的两个实现更健壮的版本:

// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>

// [[Rcpp::export]]
Rcpp::NumericVector arma_setdiff(arma::uvec& x, arma::uvec& y){

    x = arma::unique(x);
    y = arma::unique(y);

    for (size_t j = 0; j < y.n_elem; j++) {
        arma::uvec q1 = arma::find(x == y[j]);
        if (!q1.empty()) {
            x.shed_row(q1(0));
        }
    }

    Rcpp::NumericVector x2 = Rcpp::wrap(x);
    x2.attr("dim") = R_NilValue;
    return x2;
}

/*** R

x <- 1:10
y <- 2:8

R> all.equal(setdiff(x,y), arma_setdiff(x,y))
#[1] TRUE

X <- 1:6
Y <- c(2,2,3)

R> all.equal(setdiff(X,Y), arma_setdiff(X,Y))
#[1] TRUE
*/

如果您向以前的版本传递具有非唯一元素的向量,例如

R> my_setdiff2(X,Y)

error: conv_to(): given object doesn't have exactly one element

为了解决这个问题并更接近地反映 R 的 setdiff,我们只需使 xy 唯一。此外,我用 q1(0) 切换了 arma::conv_to<>::from(其中 q1 现在是 uvec 而不是 uword),因为 uvec只是一个 uword 的向量,显式转换似乎有点不雅。

我使用了 STL 中的 std::set_difference,从 arma::uvec 来回转换。

#include <RcppArmadillo.h>
#include <algorithm>

// [[Rcpp::depends(RcppArmadillo)]]

// [[Rcpp::export]]
arma::uvec std_setdiff(arma::uvec& x, arma::uvec& y) {

  std::vector<int> a = arma::conv_to< std::vector<int> >::from(arma::sort(x));
  std::vector<int> b = arma::conv_to< std::vector<int> >::from(arma::sort(y));
  std::vector<int> out;

  std::set_difference(a.begin(), a.end(), b.begin(), b.end(),
                      std::inserter(out, out.end()));

  return arma::conv_to<arma::uvec>::from(out);
}

编辑: 我认为可能需要进行性能比较。当集合的相对大小顺序相反时,差异会变小。

a <- sample.int(350)
b <- sample.int(150)

microbenchmark::microbenchmark(std_setdiff(a, b), arma_setdiff(a, b))

> Unit: microseconds
>                expr    min      lq     mean median     uq     max neval cld
>   std_setdiff(a, b) 11.548 14.7545 17.29930 17.107 19.245  36.779   100  a 
>  arma_setdiff(a, b) 60.727 65.0040 71.77804 66.714 72.702 138.133   100   b

发问者可能已经得到了答案。但是,以下模板版本可能更通用。这相当于Matlab中的setdiff函数

如果P和Q是两个集合,那么它们的差由P - Q或Q - P给出。如果P = {1, 2, 3, 4}Q = {4, 5, 6},P - Q表示P的元素不是在 Q. 即上例中 P - Q = {1, 2, 3}.

/* setdiff(t1, t2) is similar to setdiff() function in MATLAB. It removes the common elements and
   gives the uncommon elements in the vectors t1 and t2. */


template <typename T>
T setdiff(T t1, T t2)
{
    int size_of_t1 = size(t1);
    int size_of_t2 = size(t2);

    T Intersection_Elements;
    uvec iA, iB;
    intersect(Intersection_Elements, iA, iB, t1, t2);

    for (int i = 0; i < size(iA); i++)
    {
        t1(iA(i)) = 0;
    }

    for (int i = 0; i < size(iB); i++)
    {
        t2(iB(i)) = 0;
    }

    T t1_t2_vec(size_of_t1 + size_of_t2);
    t1_t2_vec = join_vert(t1, t2);
    T DiffVec = nonzeros(t1_t2_vec);


    return DiffVec;
}

欢迎提出任何改进算法性能的建议。