如何 return 引用局部变量的 Eigen 表达式?

How to return an Eigen expression that references local variables?

我想创建一个 return 是 Eigen 表达式的函数。该表达式的某些输入是该函数中的局部变量。然而,作为函数 returns,这些局部变量超出了范围。稍后在计算表达式时会读取一些垃圾数据,从而产生错误的结果。

简单示例:

template<typename T>
auto f(const T& x){
    Eigen::ArrayXXd temp = exp(x);
    return temp * (1 - temp);
}

live example

EDIT2:另一个更现实的例子,不仅仅是元素明智的操作:

template <typename T>
auto softmax(const T& x){
    Eigen::ArrayXXd tmp = exp(x - x.maxCoeff());
    return tmp / tmp.sum();
}

有没有办法将这种局部变量的生命周期延长到表达式的生命周期?或者除了在 return 从函数中获取结果之前评估结果之外没有其他方法吗?

编辑:我正在寻找优化现有库的方法。所以我不能过多地更改函数签名 - return 类型必须是 Eigen 表达式。

从 C++14 开始,您可以使用 init-capture listtemp 存储为 lambda 表达式生成的闭包的数据成员:

template<typename T>
auto f2(const T& x) {
    return [temp = exp(x).eval()](){ return temp * (1-temp); };
}

调用的语法很有趣:

Eigen::ArrayXXd y = f2(x)();

但它按预期工作,因为由 f2(x) 创建的临时对象在完整表达式结束时被销毁,所以结果可以由 y = .. 正确评估,因为 temp 仍然是还活着。

Demo

这可以通过在堆上分配这些 "local variables" 来完成。它们需要由返回的表达式拥有,因此一旦返回的表达式超出范围,就会释放内存。为此,我们需要自定义 Eigen 表达式:

template <class ArgType, typename... Ptrs>
class Holder;

namespace Eigen {
namespace internal {
template <class ArgType, typename... Ptrs>
struct traits<Holder<ArgType, Ptrs...>> {
  typedef typename ArgType::StorageKind StorageKind;
  typedef typename traits<ArgType>::XprKind XprKind;
  typedef typename ArgType::StorageIndex StorageIndex;
  typedef typename ArgType::Scalar Scalar;
  enum {
    Flags = ArgType::Flags & RowMajorBit,
    RowsAtCompileTime = ArgType::RowsAtCompileTime,
    ColsAtCompileTime = ArgType::ColsAtCompileTime,
    MaxRowsAtCompileTime = ArgType::MaxRowsAtCompileTime,
    MaxColsAtCompileTime = ArgType::MaxColsAtCompileTime
  };
};
}  // namespace internal
}  // namespace Eigen

template <typename ArgType, typename... Ptrs>
class Holder
    : public Eigen::internal::dense_xpr_base<Holder<ArgType, Ptrs...>>::type {
 public:
  Holder(const ArgType& arg, Ptrs*... pointers)
      : m_arg(arg), m_unique_ptrs(std::unique_ptr<Ptrs>(pointers)...) {}
  typedef typename Eigen::internal::ref_selector<Holder<ArgType, Ptrs...>>::type
      Nested;
  typedef Eigen::Index Index;
  Index rows() const { return m_arg.rows(); }
  Index cols() const { return m_arg.cols(); }
  typedef typename Eigen::internal::ref_selector<ArgType>::type ArgTypeNested;
  ArgTypeNested m_arg;
  std::tuple<std::unique_ptr<Ptrs>...> m_unique_ptrs;
};

namespace Eigen {
namespace internal {
template <typename ArgType, typename... Ptrs>
struct evaluator<Holder<ArgType, Ptrs...>> : evaluator_base<Holder<ArgType>> {
  typedef Holder<ArgType, Ptrs...> XprType;
  typedef typename nested_eval<ArgType, 1>::type ArgTypeNested;
  typedef typename remove_all<ArgTypeNested>::type ArgTypeNestedCleaned;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  enum {
    CoeffReadCost = evaluator<ArgTypeNestedCleaned>::CoeffReadCost,
    Flags = evaluator<ArgTypeNestedCleaned>::Flags
            & (HereditaryBits | LinearAccessBit | PacketAccessBit),
    Alignment = Unaligned & evaluator<ArgTypeNestedCleaned>::Alignment,
  };

  evaluator<ArgTypeNestedCleaned> m_argImpl;

  evaluator(const XprType& xpr) : m_argImpl(xpr.m_arg) {}

  EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col) const {
    CoeffReturnType val = m_argImpl.coeff(row, col);
    return val;
  }

  EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    CoeffReturnType val = m_argImpl.coeff(index);
    return val;
  }

  template <int LoadMode, typename PacketType>
  EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const {
    return m_argImpl.template packet<LoadMode, PacketType>(row, col);
  }

  template <int LoadMode, typename PacketType>
  EIGEN_STRONG_INLINE PacketType packet(Index index) const {
    return m_argImpl.template packet<LoadMode, PacketType>(index);
  }
};
}  // namespace internal
}  // namespace Eigen

template <typename T, typename... Ptrs>
Holder<T, Ptrs...> makeHolder(const T& arg, Ptrs*... pointers) {
  return Holder<T, Ptrs...>(arg, pointers...);
}

所以问题的功能可以这样实现:

template<typename T>
auto f(const T& x){
    auto* temp = new Eigen::ArrayXXd(exp(x));
    return makeHolder(*temp * (1 - *temp), temp);
}

这非常复杂,并且有额外动态内存分配的另一个缺点。但它确实解决了问题。