在RcppArmadillo中直接调用LAPACK例程
Calling LAPACK routine directly in RcppArmadillo
由于犰狳 (afaik) 没有三角求解器,我想使用 dtrtrs
. I have looked at the following two (first, second) 中可用的 LAPACK 三角求解器) SO 线程并将某些东西拼凑在一起,但它不起作用。
我使用 RStudio 创建了一个全新的包,同时还启用了 RcppArmadillo。我有一个头文件 header.h
:
#include <RcppArmadillo.h>
#ifdef ARMA_USE_LAPACK
#if !defined(ARMA_BLAS_CAPITALS)
#define arma_dtrtrs dtrtrs
#else
#define arma_dtrtrs DTRTRS
#endif
#endif
extern "C" {
void arma_fortran(arma_dtrtrs)(char* UPLO, char* TRANS, char* DIAG, int* N, int* NRHS,
double* A, int* LDA, double* B, int* LDB, int* INFO);
}
int trtrs(char uplo, char trans, char diag, int n, int nrhs, double* A, int lda, double* B, int ldb);
static int trisolve(const arma::mat &in_A, const arma::mat &in_b, arma::mat &out_x);
这基本上是第一个链接问题的答案,还有一个包装函数和主要函数。函数的内容在 trisolve.cpp
中,如下所示:
#include "header.h"
int trtrs(char uplo, char trans, char diag, int n, int nrhs, double* A, int lda, double* B, int ldb) {
int info = 0;
wrapper_dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, A, &lda, B, &ldb, &info);
return info;
}
static int trisolve(const arma::mat &in_A, const arma::mat &in_b, arma::mat &out_x) {
size_t rows = in_A.n_rows;
size_t cols = in_A.n_cols;
double *A = new double[rows*cols];
double *b = new double[in_b.size()];
//Lapack has column-major order
for(size_t col=0, D1_idx=0; col<cols; ++col)
{
for(size_t row = 0; row<rows; ++row)
{
// Lapack uses column major format
A[D1_idx++] = in_A(row, col);
}
b[col] = in_b(col);
}
for(size_t row = 0; row<rows; ++row)
{
b[row] = in_b(row);
}
int info = trtrs('U', 'N', 'N', cols, 1, A, rows, b, rows);
for(size_t col=0; col<cols; col++) {
out_x(col)=b[col];
}
delete[] A;
delete[] b;
return 0;
}
// [[Rcpp::export]]
arma::mat RtoRcpp(arma::mat A, arma::mat b) {
arma::uword n = A.n_rows;
arma::mat x = arma::mat(n, 1, arma::fill::zeros);
int info = trisolve(A, b, x);
return x;
}
我有(至少)两个问题:
- 尝试编译时,我从头文件中得到:
conflicting types for 'dtrtrs_'
。但是,我看不出输入有什么问题(这实际上是从第二个链接线程中复制的)。
- 不出所料,
wrapper_dtrtrts_
是不正确的。但是据我从 Armadillo 的 compiler_setup.hpp
中得知,arma_fortran
应该为我创建一个名为 wrapper_dtrtrs_
的函数。我应该在主 cpp
文件中使用什么名称?
犰狳已经有一个三角求解器。代码改编自 documentation:
mat A(5,5, fill::randu);
// ... make A triangular here ...
mat B(5,5, fill::randu);
// tell solve() to treat A as upper triangular
// and automatically enable fast mode
mat X = solve(trimatu(A), B);
根据文档,Armadillo 求解器似乎可以自动检测带状矩阵和对称正定矩阵。
Armadillo 已经使用 dtrtrs
解决三角问题。部分代码参考:
dtrtrs
在 lapack::trtrs
中被调用:https://gitlab.com/conradsnicta/armadillo-code/blob/9.200.x/include/armadillo_bits/wrapper_lapack.hpp#L908
lapack::trtrs
在 auxlib::solve_tri
中被调用,带有一个很好的调试语句:https://gitlab.com/conradsnicta/armadillo-code/blob/9.200.x/include/armadillo_bits/auxlib_meat.hpp#L3983
所以如果我们可以触发这个调试语句,我们可以确定 dtrtrs
确实被使用了:
#define ARMA_EXTRA_DEBUG
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
// [[Rcpp::export]]
void testTrisolve() {
arma::mat A = arma::randu<arma::mat>(5,5);
arma::mat B = arma::randu<arma::mat>(5,5);
arma::mat X1 = arma::solve(A, B);
arma::mat X3 = arma::solve(arma::trimatu(A), B);
}
/*** R
testTrisolve()
*/
这会产生很多调试消息,其中包括:
lapack::gesvx()
[...]
lapack::trtrs()
所以我们清楚地看到dtrtrs
用于三对角线的情况
关于你原来的问题:
- 冲突类型错误是 Aramdillo 已经使用
dtrtrs
,但签名略有不同(A
是 const
)的结果。
- Fortran 函数的 C 级名称取决于
ARMA_BLAS_UNDERSCORE
和 ARMA_USE_WRAPPER
的值。我不确定情况是否总是如此,但对我而言,前者已定义而后者未定义(c.f。config.hpp
),导致 dtrtrs_
作为名称。
确实,如果我在 Armadillo 使用它的地方添加一个 const
并将函数调用为 dtrtrs_
,您的代码编译时不会出现错误或警告(除了未使用的变量 ...) :
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
extern "C" {
void arma_fortran(dtrtrs)(char* UPLO, char* TRANS, char* DIAG, int* N, int* NRHS,
const double* A, int* LDA, double* B, int* LDB, int* INFO);
}
int trtrs(char uplo, char trans, char diag, int n, int nrhs, double* A, int lda, double* B, int ldb) {
int info = 0;
dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, A, &lda, B, &ldb, &info);
return info;
}
[...]
由于犰狳 (afaik) 没有三角求解器,我想使用 dtrtrs
. I have looked at the following two (first, second) 中可用的 LAPACK 三角求解器) SO 线程并将某些东西拼凑在一起,但它不起作用。
我使用 RStudio 创建了一个全新的包,同时还启用了 RcppArmadillo。我有一个头文件 header.h
:
#include <RcppArmadillo.h>
#ifdef ARMA_USE_LAPACK
#if !defined(ARMA_BLAS_CAPITALS)
#define arma_dtrtrs dtrtrs
#else
#define arma_dtrtrs DTRTRS
#endif
#endif
extern "C" {
void arma_fortran(arma_dtrtrs)(char* UPLO, char* TRANS, char* DIAG, int* N, int* NRHS,
double* A, int* LDA, double* B, int* LDB, int* INFO);
}
int trtrs(char uplo, char trans, char diag, int n, int nrhs, double* A, int lda, double* B, int ldb);
static int trisolve(const arma::mat &in_A, const arma::mat &in_b, arma::mat &out_x);
这基本上是第一个链接问题的答案,还有一个包装函数和主要函数。函数的内容在 trisolve.cpp
中,如下所示:
#include "header.h"
int trtrs(char uplo, char trans, char diag, int n, int nrhs, double* A, int lda, double* B, int ldb) {
int info = 0;
wrapper_dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, A, &lda, B, &ldb, &info);
return info;
}
static int trisolve(const arma::mat &in_A, const arma::mat &in_b, arma::mat &out_x) {
size_t rows = in_A.n_rows;
size_t cols = in_A.n_cols;
double *A = new double[rows*cols];
double *b = new double[in_b.size()];
//Lapack has column-major order
for(size_t col=0, D1_idx=0; col<cols; ++col)
{
for(size_t row = 0; row<rows; ++row)
{
// Lapack uses column major format
A[D1_idx++] = in_A(row, col);
}
b[col] = in_b(col);
}
for(size_t row = 0; row<rows; ++row)
{
b[row] = in_b(row);
}
int info = trtrs('U', 'N', 'N', cols, 1, A, rows, b, rows);
for(size_t col=0; col<cols; col++) {
out_x(col)=b[col];
}
delete[] A;
delete[] b;
return 0;
}
// [[Rcpp::export]]
arma::mat RtoRcpp(arma::mat A, arma::mat b) {
arma::uword n = A.n_rows;
arma::mat x = arma::mat(n, 1, arma::fill::zeros);
int info = trisolve(A, b, x);
return x;
}
我有(至少)两个问题:
- 尝试编译时,我从头文件中得到:
conflicting types for 'dtrtrs_'
。但是,我看不出输入有什么问题(这实际上是从第二个链接线程中复制的)。 - 不出所料,
wrapper_dtrtrts_
是不正确的。但是据我从 Armadillo 的compiler_setup.hpp
中得知,arma_fortran
应该为我创建一个名为wrapper_dtrtrs_
的函数。我应该在主cpp
文件中使用什么名称?
犰狳已经有一个三角求解器。代码改编自 documentation:
mat A(5,5, fill::randu);
// ... make A triangular here ...
mat B(5,5, fill::randu);
// tell solve() to treat A as upper triangular
// and automatically enable fast mode
mat X = solve(trimatu(A), B);
根据文档,Armadillo 求解器似乎可以自动检测带状矩阵和对称正定矩阵。
Armadillo 已经使用 dtrtrs
解决三角问题。部分代码参考:
dtrtrs
在lapack::trtrs
中被调用:https://gitlab.com/conradsnicta/armadillo-code/blob/9.200.x/include/armadillo_bits/wrapper_lapack.hpp#L908lapack::trtrs
在auxlib::solve_tri
中被调用,带有一个很好的调试语句:https://gitlab.com/conradsnicta/armadillo-code/blob/9.200.x/include/armadillo_bits/auxlib_meat.hpp#L3983
所以如果我们可以触发这个调试语句,我们可以确定 dtrtrs
确实被使用了:
#define ARMA_EXTRA_DEBUG
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
// [[Rcpp::export]]
void testTrisolve() {
arma::mat A = arma::randu<arma::mat>(5,5);
arma::mat B = arma::randu<arma::mat>(5,5);
arma::mat X1 = arma::solve(A, B);
arma::mat X3 = arma::solve(arma::trimatu(A), B);
}
/*** R
testTrisolve()
*/
这会产生很多调试消息,其中包括:
lapack::gesvx()
[...]
lapack::trtrs()
所以我们清楚地看到dtrtrs
用于三对角线的情况
关于你原来的问题:
- 冲突类型错误是 Aramdillo 已经使用
dtrtrs
,但签名略有不同(A
是const
)的结果。 - Fortran 函数的 C 级名称取决于
ARMA_BLAS_UNDERSCORE
和ARMA_USE_WRAPPER
的值。我不确定情况是否总是如此,但对我而言,前者已定义而后者未定义(c.f。config.hpp
),导致dtrtrs_
作为名称。
确实,如果我在 Armadillo 使用它的地方添加一个 const
并将函数调用为 dtrtrs_
,您的代码编译时不会出现错误或警告(除了未使用的变量 ...) :
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
extern "C" {
void arma_fortran(dtrtrs)(char* UPLO, char* TRANS, char* DIAG, int* N, int* NRHS,
const double* A, int* LDA, double* B, int* LDB, int* INFO);
}
int trtrs(char uplo, char trans, char diag, int n, int nrhs, double* A, int lda, double* B, int ldb) {
int info = 0;
dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, A, &lda, B, &ldb, &info);
return info;
}
[...]