使用 RcppArmadillo 用向量填充上三角矩阵(包括对角线)
filling an upper triangular matrix (including diagonal) with a vector using RcppArmadillo
我正在学习 Rcpp
软件包的功能,之前没有使用 C++
的经验。我试过:
#include <RcppArmadillo.h>
// [[Rcpp::depends("RcppArmadillo")]]
// [[Rcpp::export]]
arma::mat VtoMatCpp(int n,
arma::vec x) {
arma::mat V = arma::eye<arma::mat>(n,n) ;
V.elem(find(trimatu(V))) = x;
return(V);
}
当我在 R 中使用 sourceCpp('fun.cpp')
然后尝试 VtoMatCpp(2,1:3)
得到 Error: Mat::elem(): size mismatch
。似乎 trimatu
函数没有选择对角线的索引。
您收到错误消息是因为您的 find
调用实际上是在查找非零元素(在本例中为对角线元素)。这导致您的 VtoMatCpp(2,1:3)
调用只有 2 个元素,而 3 个元素自然太大而无法放入其中。
这与我的问题 here 有点相似,我实际上想排除对角线元素。不幸的是,我现在能想到的最好的办法就是基本上复制 R 使用 upper.tri
的方式。这是 RcppArmadillo
.
的工作示例
library(inline)
src <- '
using namespace arma;
using namespace Rcpp;
vec x = as<vec>(x_);
int n = as<int>(n_);
mat V = eye<mat>(n,n);
// make empty matrices
mat Z(n,n,fill::zeros);
mat X(n,n,fill::zeros);
// fill matrices with integers
vec idx = linspace<mat>(1,n,n);
X.each_col() += idx;
Z.each_row() += trans(idx);
// assign upper triangular elements
// the >= allows inclusion of diagonal elements
V.elem(find(Z>=X)) = x;
return(wrap(V));
'
fun <- cxxfunction(signature(n_ = "integer", x_ = "vector"),
body=src, plugin="RcppArmadillo")
fun(2,1:3)
[,1] [,2]
[1,] 1 2
[2,] 0 3
与 base
R.
完全相同
fun2 <- function(a,b){
dm <- diag(2)
dm[upper.tri(dm, diag=TRUE)] <- 1:3
dm
}
fun2(2,1:3)
[,1] [,2]
[1,] 1 2
[2,] 0 3
运行 快速基准确实表明此实现比 base
R 更快。这里我将上面的 base
解决方案包装为 fun2
.
library(microbenchmark)
microbenchmark(fun(100, seq(5050)), fun2(100, seq(5050)))
Unit: microseconds
expr min lq mean median uq max neval
fun(100, seq(5050)) 117.823 154.106 241.2361 188.2575 242.0360 3392.611 100
fun2(100, seq(5050)) 545.042 592.988 736.6958 622.7405 650.7475 4057.011 100
这周我遇到了同样的问题,多亏了arma::trimatl_ind
功能,我能想出一个满意的解决方案。虽然@cdeterman 的回答已经解决了问题,但我相信我的解决方案更容易理解,也更简洁。而且我的判断是宝答案实现的功能有误,所以我也重写了。
library(inline)
fun2 <- function(n, x) {
dm <- matrix(0, nrow = n, ncol = n)
dm[upper.tri(dm, diag=TRUE)] <- x
dm
}
src2 <- '
using namespace Rcpp;
vec x = as<vec>(x_);
int n = as<int>(n_);
arma::mat out(n, n, arma::fill::zeros);
arma::uvec lw_idx = arma::trimatl_ind( arma::size(out) );
out.elem(lw_idx) = x;
return out;
'
fun3 <- cxxfunction(signature(n_ = "integer", x_ = "vector"),
body = src, plugin = "RcppArmadillo")
fun3(2, 1:3)
#> [,1] [,2]
#> [1,] 1 2
#> [2,] 0 3
最后,我也做了个时间对比。我的实施比之前提出的 Rcpp
解决方案稍快。然而,时间比较中最令人惊讶的信息是 R 中实现的函数的性能如何提高。
microbenchmark::microbenchmark(
"r" = fun2(100, seq(5050)),
"rcpp1" = fun(100, seq(5050)),
"rcpp2" = fun3(100, seq(5050))
)
#> Unit: microseconds
#> expr min lq mean median uq max neval
#> r 74.902 84.1845 311.6322 89.2775 149.7765 20276.213 100
#> rcpp1 67.045 109.1470 141.0202 116.8895 191.0715 229.312 100
#> rcpp2 54.575 106.3015 136.6472 114.3975 182.0125 231.395 100
下面是我的会话信息。
由 reprex package (v0.3.0)
于 2020-12-09 创建
devtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.0.3 (2020-10-10)
#> os Ubuntu 20.04.1 LTS
#> system x86_64, linux-gnu
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz America/New_York
#> date 2020-12-09
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date lib source
#> assertthat 0.2.1 2019-03-21 [2] CRAN (R 4.0.2)
#> callr 3.5.1 2020-10-13 [2] CRAN (R 4.0.3)
#> cli 2.2.0 2020-11-20 [2] CRAN (R 4.0.3)
#> crayon 1.3.4 2017-09-16 [2] CRAN (R 4.0.2)
#> desc 1.2.0 2018-05-01 [2] CRAN (R 4.0.2)
#> devtools 2.3.2 2020-09-18 [2] CRAN (R 4.0.3)
#> digest 0.6.27 2020-10-24 [2] CRAN (R 4.0.3)
#> ellipsis 0.3.1 2020-05-15 [2] CRAN (R 4.0.2)
#> evaluate 0.14 2019-05-28 [2] CRAN (R 4.0.2)
#> fansi 0.4.1 2020-01-08 [2] CRAN (R 4.0.2)
#> fs 1.5.0 2020-07-31 [2] CRAN (R 4.0.3)
#> glue 1.4.2 2020-08-27 [2] CRAN (R 4.0.2)
#> highr 0.8 2019-03-20 [2] CRAN (R 4.0.2)
#> htmltools 0.5.0 2020-06-16 [2] CRAN (R 4.0.2)
#> inline * 0.3.16 2020-09-06 [2] CRAN (R 4.0.3)
#> knitr 1.30 2020-09-22 [2] CRAN (R 4.0.3)
#> magrittr 2.0.1 2020-11-17 [2] CRAN (R 4.0.3)
#> memoise 1.1.0 2017-04-21 [2] CRAN (R 4.0.2)
#> microbenchmark 1.4-7 2019-09-24 [2] CRAN (R 4.0.2)
#> pkgbuild 1.1.0 2020-07-13 [2] CRAN (R 4.0.2)
#> pkgload 1.1.0 2020-05-29 [2] CRAN (R 4.0.2)
#> prettyunits 1.1.1 2020-01-24 [2] CRAN (R 4.0.2)
#> processx 3.4.5 2020-11-30 [2] CRAN (R 4.0.3)
#> ps 1.4.0 2020-10-07 [2] CRAN (R 4.0.3)
#> R6 2.5.0 2020-10-28 [2] CRAN (R 4.0.3)
#> Rcpp 1.0.5 2020-07-06 [2] CRAN (R 4.0.2)
#> RcppArmadillo 0.9.900.3.0 2020-09-03 [2] CRAN (R 4.0.2)
#> remotes 2.2.0 2020-07-21 [2] CRAN (R 4.0.3)
#> rlang 0.4.9 2020-11-26 [2] CRAN (R 4.0.3)
#> rmarkdown 2.4 2020-09-30 [2] CRAN (R 4.0.3)
#> rprojroot 2.0.2 2020-11-15 [2] CRAN (R 4.0.3)
#> sessioninfo 1.1.1 2018-11-05 [2] CRAN (R 4.0.2)
#> stringi 1.5.3 2020-09-09 [2] CRAN (R 4.0.3)
#> stringr 1.4.0 2019-02-10 [2] CRAN (R 4.0.2)
#> testthat 3.0.0 2020-10-31 [2] CRAN (R 4.0.3)
#> usethis 1.6.3 2020-09-17 [2] CRAN (R 4.0.3)
#> withr 2.3.0 2020-09-22 [2] CRAN (R 4.0.3)
#> xfun 0.18 2020-09-29 [2] CRAN (R 4.0.3)
#> yaml 2.2.1 2020-02-01 [2] CRAN (R 4.0.2)
#>
#> [1] /home/lcgodoy/R/x86_64-pc-linux-gnu-library/4.0
#> [2] /usr/local/lib/R/site-library
#> [3] /usr/lib/R/site-library
#> [4] /usr/lib/R/library
我正在学习 Rcpp
软件包的功能,之前没有使用 C++
的经验。我试过:
#include <RcppArmadillo.h>
// [[Rcpp::depends("RcppArmadillo")]]
// [[Rcpp::export]]
arma::mat VtoMatCpp(int n,
arma::vec x) {
arma::mat V = arma::eye<arma::mat>(n,n) ;
V.elem(find(trimatu(V))) = x;
return(V);
}
当我在 R 中使用 sourceCpp('fun.cpp')
然后尝试 VtoMatCpp(2,1:3)
得到 Error: Mat::elem(): size mismatch
。似乎 trimatu
函数没有选择对角线的索引。
您收到错误消息是因为您的 find
调用实际上是在查找非零元素(在本例中为对角线元素)。这导致您的 VtoMatCpp(2,1:3)
调用只有 2 个元素,而 3 个元素自然太大而无法放入其中。
这与我的问题 here 有点相似,我实际上想排除对角线元素。不幸的是,我现在能想到的最好的办法就是基本上复制 R 使用 upper.tri
的方式。这是 RcppArmadillo
.
library(inline)
src <- '
using namespace arma;
using namespace Rcpp;
vec x = as<vec>(x_);
int n = as<int>(n_);
mat V = eye<mat>(n,n);
// make empty matrices
mat Z(n,n,fill::zeros);
mat X(n,n,fill::zeros);
// fill matrices with integers
vec idx = linspace<mat>(1,n,n);
X.each_col() += idx;
Z.each_row() += trans(idx);
// assign upper triangular elements
// the >= allows inclusion of diagonal elements
V.elem(find(Z>=X)) = x;
return(wrap(V));
'
fun <- cxxfunction(signature(n_ = "integer", x_ = "vector"),
body=src, plugin="RcppArmadillo")
fun(2,1:3)
[,1] [,2]
[1,] 1 2
[2,] 0 3
与 base
R.
fun2 <- function(a,b){
dm <- diag(2)
dm[upper.tri(dm, diag=TRUE)] <- 1:3
dm
}
fun2(2,1:3)
[,1] [,2]
[1,] 1 2
[2,] 0 3
运行 快速基准确实表明此实现比 base
R 更快。这里我将上面的 base
解决方案包装为 fun2
.
library(microbenchmark)
microbenchmark(fun(100, seq(5050)), fun2(100, seq(5050)))
Unit: microseconds
expr min lq mean median uq max neval
fun(100, seq(5050)) 117.823 154.106 241.2361 188.2575 242.0360 3392.611 100
fun2(100, seq(5050)) 545.042 592.988 736.6958 622.7405 650.7475 4057.011 100
这周我遇到了同样的问题,多亏了arma::trimatl_ind
功能,我能想出一个满意的解决方案。虽然@cdeterman 的回答已经解决了问题,但我相信我的解决方案更容易理解,也更简洁。而且我的判断是宝答案实现的功能有误,所以我也重写了。
library(inline)
fun2 <- function(n, x) {
dm <- matrix(0, nrow = n, ncol = n)
dm[upper.tri(dm, diag=TRUE)] <- x
dm
}
src2 <- '
using namespace Rcpp;
vec x = as<vec>(x_);
int n = as<int>(n_);
arma::mat out(n, n, arma::fill::zeros);
arma::uvec lw_idx = arma::trimatl_ind( arma::size(out) );
out.elem(lw_idx) = x;
return out;
'
fun3 <- cxxfunction(signature(n_ = "integer", x_ = "vector"),
body = src, plugin = "RcppArmadillo")
fun3(2, 1:3)
#> [,1] [,2]
#> [1,] 1 2
#> [2,] 0 3
最后,我也做了个时间对比。我的实施比之前提出的 Rcpp
解决方案稍快。然而,时间比较中最令人惊讶的信息是 R 中实现的函数的性能如何提高。
microbenchmark::microbenchmark(
"r" = fun2(100, seq(5050)),
"rcpp1" = fun(100, seq(5050)),
"rcpp2" = fun3(100, seq(5050))
)
#> Unit: microseconds
#> expr min lq mean median uq max neval
#> r 74.902 84.1845 311.6322 89.2775 149.7765 20276.213 100
#> rcpp1 67.045 109.1470 141.0202 116.8895 191.0715 229.312 100
#> rcpp2 54.575 106.3015 136.6472 114.3975 182.0125 231.395 100
下面是我的会话信息。
由 reprex package (v0.3.0)
于 2020-12-09 创建devtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.0.3 (2020-10-10)
#> os Ubuntu 20.04.1 LTS
#> system x86_64, linux-gnu
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz America/New_York
#> date 2020-12-09
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date lib source
#> assertthat 0.2.1 2019-03-21 [2] CRAN (R 4.0.2)
#> callr 3.5.1 2020-10-13 [2] CRAN (R 4.0.3)
#> cli 2.2.0 2020-11-20 [2] CRAN (R 4.0.3)
#> crayon 1.3.4 2017-09-16 [2] CRAN (R 4.0.2)
#> desc 1.2.0 2018-05-01 [2] CRAN (R 4.0.2)
#> devtools 2.3.2 2020-09-18 [2] CRAN (R 4.0.3)
#> digest 0.6.27 2020-10-24 [2] CRAN (R 4.0.3)
#> ellipsis 0.3.1 2020-05-15 [2] CRAN (R 4.0.2)
#> evaluate 0.14 2019-05-28 [2] CRAN (R 4.0.2)
#> fansi 0.4.1 2020-01-08 [2] CRAN (R 4.0.2)
#> fs 1.5.0 2020-07-31 [2] CRAN (R 4.0.3)
#> glue 1.4.2 2020-08-27 [2] CRAN (R 4.0.2)
#> highr 0.8 2019-03-20 [2] CRAN (R 4.0.2)
#> htmltools 0.5.0 2020-06-16 [2] CRAN (R 4.0.2)
#> inline * 0.3.16 2020-09-06 [2] CRAN (R 4.0.3)
#> knitr 1.30 2020-09-22 [2] CRAN (R 4.0.3)
#> magrittr 2.0.1 2020-11-17 [2] CRAN (R 4.0.3)
#> memoise 1.1.0 2017-04-21 [2] CRAN (R 4.0.2)
#> microbenchmark 1.4-7 2019-09-24 [2] CRAN (R 4.0.2)
#> pkgbuild 1.1.0 2020-07-13 [2] CRAN (R 4.0.2)
#> pkgload 1.1.0 2020-05-29 [2] CRAN (R 4.0.2)
#> prettyunits 1.1.1 2020-01-24 [2] CRAN (R 4.0.2)
#> processx 3.4.5 2020-11-30 [2] CRAN (R 4.0.3)
#> ps 1.4.0 2020-10-07 [2] CRAN (R 4.0.3)
#> R6 2.5.0 2020-10-28 [2] CRAN (R 4.0.3)
#> Rcpp 1.0.5 2020-07-06 [2] CRAN (R 4.0.2)
#> RcppArmadillo 0.9.900.3.0 2020-09-03 [2] CRAN (R 4.0.2)
#> remotes 2.2.0 2020-07-21 [2] CRAN (R 4.0.3)
#> rlang 0.4.9 2020-11-26 [2] CRAN (R 4.0.3)
#> rmarkdown 2.4 2020-09-30 [2] CRAN (R 4.0.3)
#> rprojroot 2.0.2 2020-11-15 [2] CRAN (R 4.0.3)
#> sessioninfo 1.1.1 2018-11-05 [2] CRAN (R 4.0.2)
#> stringi 1.5.3 2020-09-09 [2] CRAN (R 4.0.3)
#> stringr 1.4.0 2019-02-10 [2] CRAN (R 4.0.2)
#> testthat 3.0.0 2020-10-31 [2] CRAN (R 4.0.3)
#> usethis 1.6.3 2020-09-17 [2] CRAN (R 4.0.3)
#> withr 2.3.0 2020-09-22 [2] CRAN (R 4.0.3)
#> xfun 0.18 2020-09-29 [2] CRAN (R 4.0.3)
#> yaml 2.2.1 2020-02-01 [2] CRAN (R 4.0.2)
#>
#> [1] /home/lcgodoy/R/x86_64-pc-linux-gnu-library/4.0
#> [2] /usr/local/lib/R/site-library
#> [3] /usr/lib/R/site-library
#> [4] /usr/lib/R/library