接受特征密集矩阵和稀疏矩阵的函数

Function that accepts both Eigen Dense and Sparse Matrices

我正在努力为开源数学库添加稀疏矩阵支持,并且希望 DenseSparse 矩阵类型都没有重复的函数。

下面的示例显示了一个 add 函数。一个具有两个功能的工作示例,然后是两次失败的尝试。下面提供了代码示例的神栓 link。

我查看了关于编写采用 Eigen 类型的函数的 Eigen 文档,但他们使用 Eigen::EigenBase 的答案不起作用,因为 MatrixBaseSparseMatrixBase 都有特定的可用方法EigenBase

中不存在的

https://eigen.tuxfamily.org/dox/TopicFunctionTakingEigenTypes.html

我们使用 C++14,非常感谢您的帮助!

#include <Eigen/Core>
#include <Eigen/Sparse>
#include <iostream>

// Sparse matrix helper
using triplet_d = Eigen::Triplet<double>;
using sparse_mat_d = Eigen::SparseMatrix<double>;
std::vector<triplet_d> tripletList;

// Returns plain object
template <typename Derived>
using eigen_return_t = typename Derived::PlainObject;

// Below two are the generics that work
template <class Derived>
eigen_return_t<Derived> add(const Eigen::MatrixBase<Derived>& A) {
    return A + A;
}

template <class Derived>
eigen_return_t<Derived> add(const Eigen::SparseMatrixBase<Derived>& A) {
    return A + A;
}

int main()
{
  // Fill up the sparse and dense matrices
  tripletList.reserve(4);
  tripletList.push_back(triplet_d(0, 0, 1));
  tripletList.push_back(triplet_d(0, 1, 2));
  tripletList.push_back(triplet_d(1, 0, 3));
  tripletList.push_back(triplet_d(1, 1, 4));

  sparse_mat_d mat(2, 2);
  mat.setFromTriplets(tripletList.begin(), tripletList.end());

  Eigen::Matrix<double, -1, -1> v(2, 2);
  v << 1, 2, 3, 4;

  // Works fine
  sparse_mat_d output = add(mat * mat);
  std::cout << output;

  // Works fine
  Eigen::Matrix<double, -1, -1> output2 = add(v * v);
  std::cout << output2;

} 

我只想拥有一个同时接受稀疏矩阵和密集矩阵的函数,而不是两个加法函数,但是下面的尝试没有成功。

模板模板类型

我的尝试显然很糟糕,但是用模板模板类型替换上面的两个 add 函数会导致模棱两可的基础 class 错误。

template <template <class> class Container, class Derived>
Container<Derived> add(const Container<Derived>& A) {
    return A + A;    
}

错误:

<source>: In function 'int main()':
<source>:35:38: error: no matching function for call to 'add(const Eigen::Product<Eigen::SparseMatrix<double, 0, int>, Eigen::SparseMatrix<double, 0, int>, 2>)'
   35 |   sparse_mat_d output = add(mat * mat);
      |                                      ^
<source>:20:20: note: candidate: 'template<template<class> class Container, class Derived> Container<Derived> add(const Container<Derived>&)'
   20 | Container<Derived> add(const Container<Derived>& A) {
      |                    ^~~
<source>:20:20: note:   template argument deduction/substitution failed:
<source>:35:38: note:   'const Container<Derived>' is an ambiguous base class of 'const Eigen::Product<Eigen::SparseMatrix<double, 0, int>, Eigen::SparseMatrix<double, 0, int>, 2>'
   35 |   sparse_mat_d output = add(mat * mat);
      |                                      ^
<source>:40:52: error: no matching function for call to 'add(const Eigen::Product<Eigen::Matrix<double, -1, -1>, Eigen::Matrix<double, -1, -1>, 0>)'
   40 |   Eigen::Matrix<double, -1, -1> output2 = add(v * v);
      |                                                    ^
<source>:20:20: note: candidate: 'template<template<class> class Container, class Derived> Container<Derived> add(const Container<Derived>&)'
   20 | Container<Derived> add(const Container<Derived>& A) {
      |                    ^~~
<source>:20:20: note:   template argument deduction/substitution failed:
<source>:40:52: note:   'const Container<Derived>' is an ambiguous base class of 'const Eigen::Product<Eigen::Matrix<double, -1, -1>, Eigen::Matrix<double, -1, -1>, 0>'
   40 |   Eigen::Matrix<double, -1, -1> output2 = add(v * v);
      |                                                    ^

我相信这是同一个钻石继承问题:

https://www.fluentcpp.com/2017/05/19/crtp-helper/

使用std::conditional_t

下面尝试使用 conditional_t 来推断正确的输入类型

#include <Eigen/Core>
#include <Eigen/Sparse>
#include <iostream>

// Sparse matrix helper
using triplet_d = Eigen::Triplet<double>;
using sparse_mat_d = Eigen::SparseMatrix<double>;
std::vector<triplet_d> tripletList;


// Returns plain object
template <typename Derived>
using eigen_return_t = typename Derived::PlainObject;

// Check it Object inherits from DenseBase
template<typename Derived>
using is_dense_matrix_expression = std::is_base_of<Eigen::DenseBase<std::decay_t<Derived>>, std::decay_t<Derived>>;

// Check it Object inherits from EigenBase
template<typename Derived>
using is_eigen_expression = std::is_base_of<Eigen::EigenBase<std::decay_t<Derived>>, std::decay_t<Derived>>;

// Alias to deduce if input should be Dense or Sparse matrix
template <typename Derived>
using eigen_matrix = typename std::conditional_t<is_dense_matrix_expression<Derived>::value,
 typename Eigen::MatrixBase<Derived>, typename Eigen::SparseMatrixBase<Derived>>;

template <typename Derived>
eigen_return_t<Derived> add(const eigen_matrix<Derived>& A) {
    return A + A;
}

int main()
{
  tripletList.reserve(4);

  tripletList.push_back(triplet_d(0, 0, 1));
  tripletList.push_back(triplet_d(0, 1, 2));
  tripletList.push_back(triplet_d(1, 0, 3));
  tripletList.push_back(triplet_d(1, 1, 4));

  sparse_mat_d mat(2, 2);
  mat.setFromTriplets(tripletList.begin(), tripletList.end());
  sparse_mat_d output = add(mat * mat);

  std::cout << output;
  Eigen::Matrix<double, -1, -1> v(2, 2);
  v << 1, 2, 3, 4;
  Eigen::Matrix<double, -1, -1> output2 = add(v * v);
  std::cout << output2;

} 

这会引发错误

<source>: In function 'int main()':
<source>:94:38: error: no matching function for call to 'add(const Eigen::Product<Eigen::SparseMatrix<double, 0, int>, Eigen::SparseMatrix<double, 0, int>, 2>)'
   94 |   sparse_mat_d output = add(mat * mat);
      |                                      ^
<source>:79:25: note: candidate: 'template<class Derived> eigen_return_t<Derived> add(eigen_matrix<Derived>&)'
   79 | eigen_return_t<Derived> add(const eigen_matrix<Derived>& A) {
      |                         ^~~
<source>:79:25: note:   template argument deduction/substitution failed:
<source>:94:38: note:   couldn't deduce template parameter 'Derived'
   94 |   sparse_mat_d output = add(mat * mat);
      |                                      ^
<source>:99:52: error: no matching function for call to 'add(const Eigen::Product<Eigen::Matrix<double, -1, -1>, Eigen::Matrix<double, -1, -1>, 0>)'
   99 |   Eigen::Matrix<double, -1, -1> output2 = add(v * v);
      |                                                    ^
<source>:79:25: note: candidate: 'template<class Derived> eigen_return_t<Derived> add(eigen_matrix<Derived>&)'
   79 | eigen_return_t<Derived> add(const eigen_matrix<Derived>& A) {
      |                         ^~~
<source>:79:25: note:   template argument deduction/substitution failed:
<source>:99:52: note:   couldn't deduce template parameter 'Derived'
   99 |   Eigen::Matrix<double, -1, -1> output2 = add(v * v);

这好像是因为依赖类型的依赖参数不能这样推导link过了

https://deque.blog/2017/10/12/why-template-parameters-of-dependent-type-names-cannot-be-deduced-and-what-to-do-about-it/

神箭示例

下面的 godbolt 可以使用上面的所有实例

https://godbolt.org/z/yKEAsn

有什么方法可以只用一个函数而不是两个?我们有很多函数可以同时支持稀疏矩阵和稠密矩阵,所以最好避免代码重复。

编辑:可能的答案

@Max Langhof 建议使用

template <class Mat>
auto add(const Mat& A) {
 return A + A; 
}

auto 关键字对于 Eigen

有点危险

https://eigen.tuxfamily.org/dox/TopicPitfalls.html

但是

template <class Mat> 
typename Mat::PlainObject add(const Mat& A) { 
    return A + A; 
}

有效,虽然我不完全确定为什么 returning 普通对象在这种情况下有效

编辑编辑

有几个人提到了 auto 关键字的使用。可悲的是,Eigen 不能很好地与 auto 一起使用,如第二个 C++11 和 auto 在下面的 link 中所引用的

https://eigen.tuxfamily.org/dox/TopicPitfalls.html

在某些情况下可以使用 auto,但我想看看是否有一种通用的 auto 左右的方式来投诉 Eigen 的模板 return 类型

有关自动段错误的示例,您可以尝试将添加替换为

template <typename T1>
auto add(const T1& A) 
{
    return ((A+A).eval()).transpose();
}

你编译器的问题如下:

couldn't deduce template parameter 'Derived'

Derived 传递所需的类型应该可以,如下所示:

add<double>(v * v)

不过我不确定,因为在我看来 Eigen::MatrixEigen::MatrixBase 不是同一类型。

但是,如果你对编译器的类型限制少一些,它就能找出类型:

template <typename T>
auto add(const T& A) {
    return A + A;
}

编辑:

刚刚在评论中看到这个解决方案已经发布,Eigen文档建议不要使用auto。我不熟悉 Eigen,但在我看来,浏览文档时,Eigen 可能会产生代表表达式的结果——例如将矩阵加法表示为算法的对象;不是矩阵加法结果本身。在这种情况下,如果您知道 A + A 导致类型 T(在我看来它实际上应该用于 operator+),您可以像下面这样写:

template <typename T>
T add(const T& A) {
    return A + A;
}

在矩阵示例中,这应该强制返回矩​​阵结果;不是表示表达式的对象。但是,由于您最初使用的是 eigen_result_t,因此我不能 100% 确定。

如果你想传递 EigenBase<Derived>,你可以使用 .derived() 提取底层类型(本质上,这只是转换为 Derived const&):

template <class Derived>
eigen_return_t<Derived> add(const Eigen::EigenBase<Derived>& A_) {
    Derived const& A = A_.derived();
    return A + A;
}

更高级,对于这个特定的例子,因为你使用了两次 A,你可以使用内部计算器结构来表达:

template <class Derived>
eigen_return_t<Derived> add2(const Eigen::EigenBase<Derived>& A_) {
    // A is used twice:
    typedef typename Eigen::internal::nested_eval<Derived,2>::type NestedA;
    NestedA A (A_.derived());
    return A + A;
}

这样做的好处是,当将产品作为 A_ 传递时,它不会在评估 A+A 时被评估两次,但如果 A_ 类似于 Block<...> 它不会被不必要地复制。但是,并不真正推荐使用 internal 功能(API 可能随时更改)。

我还没有完全理解你的代码和评论。无论如何,您的问题似乎已简化为找到一种方法来编写可以接受多种矩阵类型的函数。

template <typename T>
auto add(const T& A)
{
    return 2*A;
}

您还可以添加 2 个不同类型的矩阵:

template <typename T1, typename T2>
auto add(const T1& A, const T2& B) -> decltype(A+B) // decltype can be omitted since c++14
{
    return A + B;
}

然后,add(A,A) 给出与 add(A) 相同的结果。但是我认为带有 2 个参数的 add 函数更有意义。而且它更通用,因为您可以将稀疏矩阵与密集矩阵相加。

int main()
{
    constexpr size_t size = 10;
    Eigen::SparseMatrix<double> spm_heap(size,size);
    Eigen::MatrixXd m_heap(size,size);
    Eigen::Matrix<double,size,size> m_stack; 

    // fill the matrices

    std::cout << add(spm_heap,m_heap);
    std::cout << add(spm_heap,m_stack);

    return 0;
}

编辑

关于您声明 auto 不应与 Eigen 一起使用的编辑。这很有趣!

template <typename T>
auto add(const T& A) 
{
    return ((A+A).eval()).transpose();
}

这会产生 segfault。为什么? auto 确实很好地推导了类型,但推导的类型不是 decltype(A),而是该类型的 reference。为什么?我首先认为这是因为 return 值周围的括号(如果有兴趣请阅读 here),但它似乎是由于 transpose 函数的 return 类型.

不管怎样,这个问题很容易解决。正如您所建议的,您可以删除 auto:

template <typename T>
T add(const T& A) 
{
    return ((A+A).eval()).transpose();
}

或者,您可以使用 auto 但指定所需的 return 类型:

template <typename T>
auto add(const T& A) -> typename std::remove_reference<decltype(A)>::type // or simply decltype(A.eval())
{
    return ((A+A).eval()).transpose();
}

现在,对于这个特定的 add 函数,第一个选项(省略 auto)是最佳解决方案。然而,对于另一个接受 2 个不同类型参数的 add 函数,这是一个很好的解决方案:

template <typename T1, typename T2>
auto add(const T1& A, const T2& B) -> decltype((A+B).eval())
{
    return ((A+B).eval()).transpose();
}