具有 unique_ptr 和矩阵的持久表达式模板

Persistent expression templates with unique_ptr and matrices

我想使用表达式模板来创建跨语句持续存在的对象树。构建树最初涉及一些使用 Eigen 线性代数库的计算。持久表达式模板将有额外的方法来通过以不同方式遍历树来计算其他数量(但我还没有)。

为避免临时对象超出范围的问题,子表达式对象通过 std::unique_ptr 进行管理。在构建表达式树时,指针应该向上传播,以便持有根对象的指针可确保所有对象都保持活动状态。由于 Eigen 创建的表达式模板包含对在语句末尾超出范围的临时对象的引用,因此情况变得复杂,因此必须在构建树时评估所有 Eigen 表达式。

下面是一个按比例缩小的实现,当 val 类型是一个包含整数的对象时,它似乎可以工作,但对于 Matrix 类型,它在构造 output_xpr 对象时崩溃。崩溃的原因似乎是 Eigen 的矩阵乘积表达式模板 (Eigen::GeneralProduct) 在使用之前被损坏了。但是,我自己的表达式对象或 GeneralProduct 的析构函数的 none 似乎在崩溃发生之前被调用,并且 valgrind 没有检测到任何无效的内存访问。

任何帮助将不胜感激!我也很欣赏关于我将移动构造函数与静态继承一起使用的评论,也许问题出在某个地方。

#include <iostream>
#include <memory>

#include <Eigen/Core>

typedef Eigen::MatrixXi val;

// expression_ptr and derived_ptr: contain unique pointers
// to the actual expression objects

template<class Derived>
struct expression_ptr {
    Derived &&transfer_cast() && {
        return std::move(static_cast<Derived &&>(*this));
    }
};

template<class A>
struct derived_ptr : public expression_ptr<derived_ptr<A>> {
    derived_ptr(std::unique_ptr<A> &&p) : ptr_(std::move(p)) {}
    derived_ptr(derived_ptr<A> &&o) : ptr_(std::move(o.ptr_)) {}

    auto operator()() const {
        return (*ptr_)();
    }

private:
    std::unique_ptr<A> ptr_;
};

// value_xpr, product_xpr and output_xpr: expression templates
// doing the actual work

template<class A>
struct value_xpr {
    value_xpr(const A &v) : value_(v) {}

    const A &operator()() const {
        return value_;
    }

private:
    const A &value_;
};

template<class A,class B>
struct product_xpr {
    product_xpr(expression_ptr<derived_ptr<A>> &&a, expression_ptr<derived_ptr<B>> &&b) :
        a_(std::move(a).transfer_cast()), b_(std::move(b).transfer_cast()) {
    }

    auto operator()() const {
        return a_() * b_();
    }

private:
    derived_ptr<A> a_;
    derived_ptr<B> b_;
};

// Top-level expression with a matrix to hold the completely
// evaluated output of the Eigen calculations
template<class A>
struct output_xpr {
    output_xpr(expression_ptr<derived_ptr<A>> &&a) :
        a_(std::move(a).transfer_cast()), result_(a_()) {}

    const val &operator()() const {
        return result_;
    }

private:
    derived_ptr<A> a_;
    val result_;
};

// helper functions to create the expressions

template<class A>
derived_ptr<value_xpr<A>> input(const A &a) {
    return derived_ptr<value_xpr<A>>(std::make_unique<value_xpr<A>>(a));
}

template<class A,class B>
derived_ptr<product_xpr<A,B>> operator*(expression_ptr<derived_ptr<A>> &&a, expression_ptr<derived_ptr<B>> &&b) {
    return derived_ptr<product_xpr<A,B>>(std::make_unique<product_xpr<A,B>>(std::move(a).transfer_cast(), std::move(b).transfer_cast()));
}

template<class A>
derived_ptr<output_xpr<A>> eval(expression_ptr<derived_ptr<A>> &&a) {
    return derived_ptr<output_xpr<A>>(std::make_unique<output_xpr<A>>(std::move(a).transfer_cast()));
}

int main() {
    Eigen::MatrixXi mat(2, 2);
    mat << 1, 1, 0, 1;
    val one(mat), two(mat);
    auto xpr = eval(input(one) * input(two));
    std::cout << xpr() << std::endl;
    return 0;
}

您的问题似乎是您正在使用其他人的表达式模板,并将结果存储在 auto

(这发生在 product_xpr<A>::operator() 中,您在这里调用 *,如果我没看错的话,它是使用表达式模板的特征乘法)。

表达式模板通常被设计为假定整个表达式将出现在一行中,并且它将以导致表达式模板被评估的接收器类型(如矩阵)结束。

在你的例子中,你有 a*b 表达式模板,然后用它来构造一个表达式模板 return 值,你稍后会计算它。在 a*b 中传递给 * 的临时对象的生命周期将在您到达接收器类型(矩阵)时结束,这违反了表达式模板的预期。

我正在努力想出一个解决方案来确保所有临时对象的生命周期都得到延长。有人认为我有某种延续传递风格,而不是调用:

Matrix m = (a*b);

你会

auto x = { do (a*b) pass that to (cast to matrix) }

替换

auto operator()() const {
    return a_() * b_();
}

template<class F>
auto operator()(F&& f) const {
    return std::forward<F>(f)(a_() * b_());
}

其中“下一步”被传递给每个子表达式。这对于二进制表达式来说更棘手,因为您必须确保第一个表达式的计算调用导致第二个子表达式被计算的代码,然后将两个表达式组合在一起,都在同一个长递归调用堆栈中。

我对延续传递风格还不够熟练,无法完全解开这个结,但它在函数式编程世界中有点流行。

另一种方法是将您的树展平为可选的元组,然后使用花哨的 operator() 构造树中的每个可选,并以这种方式手动连接参数。基本上对中间值进行手动内存管理。如果 Eigen 表达式模板是移动感知的或没有任何自指针,这将起作用,因此在构造点移动不会破坏事物。写那会很有挑战性。

建议的延续传递风格解决了问题并且不太疯狂(总的来说不比模板元编程更疯狂)。二进制表达式参数的双重 lambda 评估可以隐藏在辅助函数中,请参见下面代码中的 binary_cont。作为参考,并且由于它并非完全微不足道,所以我在此处发布固定代码。

如果有人理解为什么我必须在 binary_cont 中的 F 类型上加上 const 限定符,请告诉我。

#include <iostream>
#include <memory>

#include <Eigen/Core>

typedef Eigen::MatrixXi val;

// expression_ptr and derived_ptr: contain unique pointers
// to the actual expression objects

template<class Derived>
struct expression_ptr {
    Derived &&transfer_cast() && {
        return std::move(static_cast<Derived &&>(*this));
    }
};

template<class A>
struct derived_ptr : public expression_ptr<derived_ptr<A>> {
    derived_ptr(std::unique_ptr<A> &&p) : ptr_(std::move(p)) {}
    derived_ptr(derived_ptr<A> &&o) = default;

    auto operator()() const {
        return (*ptr_)();
    }

    template<class F>
    auto operator()(F &&f) const {
        return (*ptr_)(std::forward<F>(f));
    }

private:
    std::unique_ptr<A> ptr_;
};

template<class A,class B,class F>
auto binary_cont(const derived_ptr<A> &a_, const derived_ptr<B> &b_, const F &&f) {
    return a_([&b_, f = std::forward<const F>(f)] (auto &&a) {
        return b_([a = std::forward<decltype(a)>(a), f = std::forward<const F>(f)] (auto &&b) {
            return std::forward<const F>(f)(std::forward<decltype(a)>(a), std::forward<decltype(b)>(b));
        });
    });
}

// value_xpr, product_xpr and output_xpr: expression templates
// doing the actual work

template<class A>
struct value_xpr {
    value_xpr(const A &v) : value_(v) {}

    template<class F>
    auto operator()(F &&f) const {
        return std::forward<F>(f)(value_);
    }

private:
    const A &value_;
};

template<class A,class B>
struct product_xpr {
    product_xpr(expression_ptr<derived_ptr<A>> &&a, expression_ptr<derived_ptr<B>> &&b) :
        a_(std::move(a).transfer_cast()), b_(std::move(b).transfer_cast()) {
    }

    template<class F>
    auto operator()(F &&f) const {
        return binary_cont(a_, b_,
            [f = std::forward<F>(f)] (auto &&a, auto &&b) {
                return f(std::forward<decltype(a)>(a) * std::forward<decltype(b)>(b));
            });
    }

private:
    derived_ptr<A> a_;
    derived_ptr<B> b_;
};

template<class A>
struct output_xpr {
    output_xpr(expression_ptr<derived_ptr<A>> &&a) :
            a_(std::move(a).transfer_cast()) {
        a_([this] (auto &&x) { this->result_ = x; });
    }

    const val &operator()() const {
        return result_;
    }

private:
    derived_ptr<A> a_;
    val result_;
};

// helper functions to create the expressions

template<class A>
derived_ptr<value_xpr<A>> input(const A &a) {
    return derived_ptr<value_xpr<A>>(std::make_unique<value_xpr<A>>(a));
}

template<class A,class B>
derived_ptr<product_xpr<A,B>> operator*(expression_ptr<derived_ptr<A>> &&a, expression_ptr<derived_ptr<B>> &&b) {
    return derived_ptr<product_xpr<A,B>>(std::make_unique<product_xpr<A,B>>(std::move(a).transfer_cast(), std::move(b).transfer_cast()));
}

template<class A>
derived_ptr<output_xpr<A>> eval(expression_ptr<derived_ptr<A>> &&a) {
    return derived_ptr<output_xpr<A>>(std::make_unique<output_xpr<A>>(std::move(a).transfer_cast()));
}

int main() {
    Eigen::MatrixXi mat(2, 2);
    mat << 1, 1, 0, 1;
    val one(mat), two(mat), three(mat);
    auto xpr = eval(input(one) * input(two) * input(one) * input(two));
    std::cout << xpr() << std::endl;
    return 0;
}