在RcppArmadillo中直接调用LAPACK例程

Calling LAPACK routine directly in RcppArmadillo

由于犰狳 (afaik) 没有三角求解器,我想使用 dtrtrs. I have looked at the following two (first, second) 中可用的 LAPACK 三角求解器) SO 线程并将某些东西拼凑在一起,但它不起作用。

我使用 RStudio 创建了一个全新的包,同时还启用了 RcppArmadillo。我有一个头文件 header.h:

#include <RcppArmadillo.h>

#ifdef ARMA_USE_LAPACK
#if !defined(ARMA_BLAS_CAPITALS)
#define arma_dtrtrs dtrtrs
#else
#define arma_dtrtrs DTRTRS
#endif
#endif

extern "C" {
  void arma_fortran(arma_dtrtrs)(char* UPLO, char* TRANS, char* DIAG, int* N, int* NRHS,
                    double* A, int* LDA, double* B, int* LDB, int* INFO);
}

int trtrs(char uplo, char trans, char diag, int n, int nrhs, double* A, int lda, double* B, int ldb);

static int trisolve(const arma::mat &in_A, const arma::mat &in_b, arma::mat &out_x);

这基本上是第一个链接问题的答案,还有一个包装函数和主要函数。函数的内容在 trisolve.cpp 中,如下所示:

#include "header.h"

int trtrs(char uplo, char trans, char diag, int n, int nrhs, double* A, int lda, double* B, int ldb) {
  int info = 0;
  wrapper_dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, A, &lda, B, &ldb, &info);
  return info;
}


static int trisolve(const arma::mat &in_A, const arma::mat &in_b, arma::mat &out_x) {
  size_t  rows = in_A.n_rows;
  size_t  cols = in_A.n_cols;

  double *A = new double[rows*cols];
  double *b = new double[in_b.size()];

  //Lapack has column-major order
  for(size_t col=0, D1_idx=0; col<cols; ++col)
  {
    for(size_t row = 0; row<rows; ++row)
    {
      // Lapack uses column major format
      A[D1_idx++] = in_A(row, col);
    }
    b[col] = in_b(col);
  }

  for(size_t row = 0; row<rows; ++row)
  {
    b[row] = in_b(row);
  }

  int info = trtrs('U', 'N', 'N', cols, 1, A, rows, b, rows);

  for(size_t col=0; col<cols; col++) {
    out_x(col)=b[col];
  }

  delete[] A;
  delete[] b;

  return 0;
}


// [[Rcpp::export]]

arma::mat RtoRcpp(arma::mat A, arma::mat b) {
  arma::uword n = A.n_rows;
  arma::mat x = arma::mat(n, 1, arma::fill::zeros);

  int info = trisolve(A, b, x);
  return x;
}

我有(至少)两个问题:

  1. 尝试编译时,我从头文件中得到:conflicting types for 'dtrtrs_'。但是,我看不出输入有什么问题(这实际上是从第二个链接线程中复制的)。
  2. 不出所料,wrapper_dtrtrts_ 是不正确的。但是据我从 Armadillo 的 compiler_setup.hpp 中得知,arma_fortran 应该为我创建一个名为 wrapper_dtrtrs_ 的函数。我应该在主 cpp 文件中使用什么名称?

犰狳已经有一个三角求解器。代码改编自 documentation:

mat A(5,5, fill::randu);
// ... make A triangular here ...

mat B(5,5, fill::randu);

// tell solve() to treat A as upper triangular
// and automatically enable fast mode
mat X = solve(trimatu(A), B);

根据文档,Armadillo 求解器似乎可以自动检测带状矩阵和对称正定矩阵。

Armadillo 已经使用 dtrtrs 解决三角问题。部分代码参考:

所以如果我们可以触发这个调试语句,我们可以确定 dtrtrs 确实被使用了:

#define ARMA_EXTRA_DEBUG
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>

// [[Rcpp::export]]
void testTrisolve() {
  arma::mat A = arma::randu<arma::mat>(5,5);
  arma::mat B = arma::randu<arma::mat>(5,5);

  arma::mat X1 = arma::solve(A, B);
  arma::mat X3 = arma::solve(arma::trimatu(A), B);
}

/*** R
testTrisolve()
*/

这会产生很多调试消息,其中包括:

lapack::gesvx()
[...]
lapack::trtrs()

所以我们清楚地看到dtrtrs用于三对角线的情况

关于你原来的问题:

  1. 冲突类型错误是 Aramdillo 已经使用 dtrtrs,但签名略有不同(Aconst)的结果。
  2. Fortran 函数的 C 级名称取决于 ARMA_BLAS_UNDERSCOREARMA_USE_WRAPPER 的值。我不确定情况是否总是如此,但对我而言,前者已定义而后者未定义(c.f。config.hpp),导致 dtrtrs_ 作为名称。

确实,如果我在 Armadillo 使用它的地方添加一个 const 并将函数调用为 dtrtrs_,您的代码编译时不会出现错误或警告(除了未使用的变量 ...) :

// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>

extern "C" {
  void arma_fortran(dtrtrs)(char* UPLO, char* TRANS, char* DIAG, int* N, int* NRHS,
                    const double* A, int* LDA, double* B, int* LDB, int* INFO);
}

int trtrs(char uplo, char trans, char diag, int n, int nrhs, double* A, int lda, double* B, int ldb) {
  int info = 0;
  dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, A, &lda, B, &ldb, &info);
  return info;
}

[...]