RcppArmadillo: How to replace NAs in a vector with another vector

Setup

Suppose I have a vector x that contains some NAs and a vector y that contains the replacements for those NAs:

x <- c(NA, 2, NA, 4)
y <- c(80, 90)

I am trying to write an RcppArmadillo function that reproduces this operation:

x[is.na(x)] <- y

so that x == c(80, 2, 90, 4).

What I have tried

After reading some documentation, I was able to write a short function that replaces the NAs in x with zeros:

Rcpp::cppFunction(
    depends = 'RcppArmadillo',
    'arma::vec f(arma::vec x) {
        x.elem(find_nonfinite(x)).zeros();
        return(x);
     }'
)

It behaves as expected:

r$> f(x)
     [,1]
[1,]    0
[2,]    2
[3,]    0
[4,]    4

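However, I could not find a one-line (or few-line) way to do what I actually want. I tried x.replace(), but I could never get the types to match. I suppose another way would be to loop over the elements y(j) and x(i) and, whenever x(i) == NA_INTEGER, do x(i) = y(j), but that sounds like it would take more lines than it should.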
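The element-by-element loop described above can be sketched in plain C++ without Armadillo. This is only an illustration under stated assumptions: `replace_nan_loop` is a hypothetical name, `std::vector<double>` stands in for `arma::vec`, and missing values are detected with `std::isnan`, because R stores NA in a double vector as a NaN value, so an `==` comparison against `NA_INTEGER` would not detect it.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper: returns a copy of x in which the k-th NaN
// is replaced by y[k].
std::vector<double> replace_nan_loop(std::vector<double> x,
                                     const std::vector<double>& y) {
    std::size_t j = 0;  // index of the next unused replacement in y
    for (std::size_t i = 0; i < x.size(); ++i) {
        if (std::isnan(x[i])) {
            if (j == y.size())
                throw std::invalid_argument("y has fewer elements than x has NaNs");
            x[i] = y[j++];
        }
    }
    return x;
}
```

Called with the x and y from the setup, this should yield 80, 2, 90, 4. The loop is only a handful of lines, so the explicit version is not much longer than an index-based one-liner.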

For visibility, here is a solution based on user2957945's comment above:

cppFunction('
    arma::vec f2(arma::vec& x, arma::vec& y) {
        x.elem(find_nonfinite(x)) = y;
        return(x);
    }',
    depends='RcppArmadillo'
)
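Conceptually, the one-liner does two things: `find_nonfinite(x)` collects the indices of the non-finite elements, and `x.elem(...) = y` writes `y` into those positions in order. A rough plain-C++ sketch of these two steps follows. The names `nonfinite_indices` and `assign_at` are hypothetical, `std::vector` stands in for the Armadillo types, and the size check is a choice made for this sketch rather than a description of Armadillo's internals.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Step 1 (hypothetical analogue of find_nonfinite):
// indices of the elements that are NaN or +/-Inf.
std::vector<std::size_t> nonfinite_indices(const std::vector<double>& x) {
    std::vector<std::size_t> idx;
    for (std::size_t i = 0; i < x.size(); ++i)
        if (!std::isfinite(x[i])) idx.push_back(i);
    return idx;
}

// Step 2 (hypothetical analogue of x.elem(idx) = y):
// write y[k] into x[idx[k]] for every k.
void assign_at(std::vector<double>& x,
               const std::vector<std::size_t>& idx,
               const std::vector<double>& y) {
    if (idx.size() != y.size())
        throw std::invalid_argument("number of indices and replacements differ");
    for (std::size_t k = 0; k < idx.size(); ++k)
        x.at(idx[k]) = y[k];  // at() checks that the index is in range
}
```

In this sketch the intermediate index vector is an extra allocation that a single-pass loop over x does not need.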

You can avoid any copying as follows:

# create the function
Rcpp::sourceCpp(code = '
    // [[Rcpp::depends(RcppArmadillo)]]
    #include <RcppArmadillo.h>
    #include <cmath>
    
    // [[Rcpp::export(rng = false)]]
    void f(arma::vec &x, arma::vec const &y) {
        auto yi = y.begin();
        for(auto &xi : x)
           if(std::isnan(xi)){
               if(yi == y.end())
                   throw std::runtime_error("y is too short");
               xi = *yi++;
           }
     }')

# check the result
x <- c(NA, 2, NA, 4)
y <- c(80, 90)
.Internal(inspect(x)) 
#R> @563e025f7448 14 REALSXP g1c3 [MARK,REF(5)] (len=4, tl=0) nan,2,nan,4

f(x, y)  
x
#R> [1] 80  2 90  4
.Internal(inspect(x))
#R> @563e025f7448 14 REALSXP g1c3 [MARK,REF(6)] (len=4, tl=0) 80,2,90,4
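The unchanged address in the two `inspect` outputs above indicates that x was modified in place. The same distinction between working on a copy and working on the caller's storage can be sketched in plain C++. This is an analogy only: `zero_nan_copy` and `zero_nan_in_place` are hypothetical names, and `std::vector<double>` stands in for `arma::vec`.

```cpp
#include <cmath>
#include <vector>

// By value: the function works on its own copy,
// so the caller's vector is left unchanged.
std::vector<double> zero_nan_copy(std::vector<double> x) {
    for (double& xi : x)
        if (std::isnan(xi)) xi = 0.0;
    return x;
}

// By reference: the caller's storage is modified directly, with no copy.
void zero_nan_in_place(std::vector<double>& x) {
    for (double& xi : x)
        if (std::isnan(xi)) xi = 0.0;
}
```

Comparing `x.data()` before and after a call to `zero_nan_in_place` plays the same role here as comparing the addresses printed by `inspect`.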

If you know that yi != y.end() always holds, you can drop the check. An alternative is to use for_each:

Rcpp::sourceCpp(code = '
    // [[Rcpp::depends(RcppArmadillo)]]
    #include <RcppArmadillo.h>
    #include <cmath>
    
    // [[Rcpp::export(rng = false)]]
    void f(arma::vec &x, arma::vec const &y) {
        auto yi = y.begin();
        x.for_each([&](double &xi){
            if(std::isnan(xi)){
               if(yi == y.end())
                   throw std::runtime_error("y is too short");
               xi = *yi++;
           }
        });
     }')
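One behavioural difference between the loop versions and the `find_nonfinite` version is worth keeping in mind: `find_nonfinite` matches +/-Inf as well as NaN, whereas `std::isnan` matches only NaN (which covers NA in an R double vector). A small plain-C++ sketch of the two predicates, with hypothetical helper names `count_nan` and `count_nonfinite`:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Counts the elements that std::isnan treats as missing (NaN only).
std::size_t count_nan(const std::vector<double>& x) {
    std::size_t n = 0;
    for (double xi : x)
        if (std::isnan(xi)) ++n;
    return n;
}

// Counts the elements that are not finite (NaN, +Inf or -Inf).
std::size_t count_nonfinite(const std::vector<double>& x) {
    std::size_t n = 0;
    for (double xi : x)
        if (!std::isfinite(xi)) ++n;
    return n;
}
```

If x can contain infinite values, the two approaches would therefore consume different numbers of elements from y.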

Here is a small benchmark:

# create a small data set for a benchmark
set.seed(1)
n <- 1000L
x <- c(1, rnorm(n))
x[runif(n) < .2] <- NA
y <- rnorm(sum(is.na(x)))

bench::mark(
    R = { z <- x; z[1] <- z[1] + 0.; z[is.na(z)] <- y; z }, 
    user2957945 = { z <- x; z[1] <- z[1] + 0.; drop(f2(z, y)) },
    inplace = { z <- x; z[1] <- z[1] + 0.; f(z, y); drop(z) }, 
    `inplace no range check` = 
        { z <- x; z[1] <- z[1] + 0.; f_no_range_check(z, y); drop(z) })
#R> # A tibble: 4 x 13
#R>   expression                  min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result        memory             time                gc                   
#R>   <bch:expr>             <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list>        <list>             <list>              <list>               
#R> 1 R                        4.03µs   4.66µs   157070.   16.66KB     47.1  9997     3     63.6ms <dbl [1,001]> <Rprofmem [4 × 3]> <bench_tm [10,000]> <tibble [10,000 × 3]>
#R> 2 user2957945              5.71µs   6.38µs   150675.   18.23KB     45.2  9997     3     66.3ms <dbl [1,001]> <Rprofmem [3 × 3]> <bench_tm [10,000]> <tibble [10,000 × 3]>
#R> 3 inplace                  2.74µs   3.09µs   316877.    7.87KB     63.4  9998     2     31.6ms <dbl [1,001]> <Rprofmem [1 × 3]> <bench_tm [10,000]> <tibble [10,000 × 3]>
#R> 4 inplace no range check   3.47µs   3.82µs   254082.   10.36KB     50.8  9998     2     39.4ms <dbl [1,001]> <Rprofmem [2 × 3]> <bench_tm [10,000]> <tibble [10,000 × 3]>

# the time that should be subtracted
bench::mark(`copy cost` = { z <- x; z[1] <- z[1] + 0. })
#R> # A tibble: 1 x 13
#R>   expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result    memory             time                gc                   
#R>   <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list>    <list>             <list>              <list>               
#R> 1 copy cost    1.18µs    1.5µs   528254.    7.87KB     106.  9998     2     18.9ms <dbl [1]> <Rprofmem [1 × 3]> <bench_tm [10,000]> <tibble [10,000 × 3]>

Here are the function definitions:

Rcpp::cppFunction('
    arma::vec f2(arma::vec x, arma::vec& y) {
        x.elem(find_nonfinite(x)) = y;
        return(x);
    }', depends='RcppArmadillo')
                 
Rcpp::sourceCpp(code = '
    // [[Rcpp::depends(RcppArmadillo)]]
    #include <RcppArmadillo.h>
    #include <cmath>
    
    // [[Rcpp::export(rng = false)]]
    void f(arma::vec &x, arma::vec const &y) {
        auto yi = y.begin();
        x.for_each([&](double &xi){
            if(std::isnan(xi)){
               if(yi == y.end())
                   throw std::runtime_error("y is too short");
               xi = *yi++;
           }
        });
     }')

Rcpp::cppFunction('void f_no_range_check(arma::vec &x, arma::vec const &y) {
                       auto yi = y.begin();
                       x.for_each([&](double &xi){
                           if(std::isnan(xi)) xi = *yi++;
                       });
                   }', depends='RcppArmadillo')

The difference between the two in-place versions comes from Rcpp::export(rng = false); the default, rng = true, adds some overhead.