从 sympy 生成优化的八度代码

generating optimized octave code from sympy

我有一些巨大的矩阵要导出,其中仅包含 sin(q)、cos(q) 和 sums/muls。 Sympy 可以计算并将其导出到八度音阶——太棒了! 但是,由于这些是大矩阵,我需要某种 cse 甚至更好的专门优化。

我找到了 this great tutorial for C code with cse。所以我尝试自己移植它,但我在打印机 class 的一些细节上失败了。我认为这是一个无限递归导致 RecursionError: maximum recursion depth exceeded.

我的问题是:是否有 sympy-octave 代码生成和优化结合在一起的示例?或者谁能​​帮我拿到附件的mwe 运行?

import sympy as sp
t = sp.symbols('t')

from sympy.printing.octave import OctaveCodePrinter
from sympy.printing.octave import Assignment
class matlabMatrixPrinter(OctaveCodePrinter):

    def _print_ImmutableDenseMatrix(self, expr):
        sub_exprs, simplified = sp.cse(expr)
        lines = []
        for var, sub_expr in sub_exprs:
            lines.append( self._print(Assignment(var, sub_expr)))
        M = sp.MatrixSymbol('M', *expr.shape)
        return '\n'.join(lines) + '\n' + self._print(Assignment(M, expr))

tmp = sp.sin(t)+sp.sin(t)**2
tmp = sp.ImmutableDenseMatrix((1,1,tmp))
se, ex = sp.cse(tmp)
print((ex,se))
print('\n')
#tmp = sp.Matrix([2*sp.sin(t),sp.sin(t)])
p = matlabMatrixPrinter()
print(p.doprint(tmp))

编辑:我现在想通了,return 语句中的第二个赋值也运行函数 _print_ImmutableDenseMatrix,所以这最终成为一个递归。我不知道为什么在教程中这对 C 代码没有问题,但在这里它递归运行。这似乎只是简化表达式本身的问题,无法调用 self._print 函数。也许有人对这些打印机有所了解,知道应该如何打印矩阵和这个单一作业?!

经过大量实验后,我觉得我仍然只了解 codePrinter 有意工作流程背后的一些意图。然而,我写了一个完全符合我预期的子类(小心,因为这可能不适用于矩阵以外的任何东西!)。

也许这对某人有用!对我来说,它肯定验证了 sympy 作为一个工作工具,否则成千上万的 sin 评估将是绝对不可行的代码。

我仍然对某人的评论和想法非常感兴趣,谁知道应该如何实现这些功能!

import sympy as sp
t = sp.symbols('t')
from sympy.printing.octave import OctaveCodePrinter
from sympy.printing.octave import Assignment
class matlabMatrixPrinter(OctaveCodePrinter):
    def print2(self,expr_list,names=None):
        sub_exprs, simplified = sp.cse(expr_list)
        lines = []
        for var, sub_expr in sub_exprs:
            lines.append(self._print(Assignment(var, sub_expr)))
        lines.append('')
        for k,expr in enumerate(simplified):
            if names:
                M = sp.MatrixSymbol(names[k],*expr.shape)
            else:
                M = sp.MatrixSymbol('M{k}'.format(k=k), *expr.shape)
            lines.append(self._print(Assignment(M,expr)))
        result = ''
        return '\n'.join(lines)

tmp = sp.Matrix([sp.sin(t)+sp.sin(t)**2 ])
tmp2 = sp.Matrix([sp.sin(t),sp.cos(t),2*sp.sin(t),sp.cos(t)**2])

p = matlabMatrixPrinter()
#print(p.print2([tmp,tmp2]))
print(p.print2([tmp,tmp2],['scalar_matrix','matrix']));

这给出了预期的输出:

x0 = sin(t);
x1 = cos(t);
scalar_matrix = x0.^2 + x0;
matrix = [x0; x1; 2*x0; x1.^2];

如上所述:使用风险自负:)