Sympy lambdify ImmutableDenseMatrix 与 numexpr

Sympy lambdify ImmutableDenseMatrix with numexpr

我尝试使用 lambdify 加速 MutableDenseMatrix 的计算。它与模块 'numpy' 一起工作。 'Numexpr' 应该更快(因为我需要评估来解决大型优化问题)。

给出了我正在尝试做的一个较小的例子
from sympy import symbols, cos, Matrix, lambdify

a11, a12, a21, a22, b11, b12, b21, b22, u = symbols("a11 a12 a21 a22 b11 b12 b21 b22 u")
A = Matrix([[a11, a12], [a21, a22]])
B = Matrix([[b11, b12], [b21, b22]])
expr = A * (B ** 2) * cos(u) + A ** (-3 / 2)
f = lambdify((A, B), expr, modules='numexpr')

它引发错误

TypeError: numexpr cannot be used with ImmutableDenseMatrix

有没有办法对 DenseMatrices 使用 lambdify?或者另一个如何加快评估的想法?

提前致谢!

一个使用 numexpr 的可能解决方案是单独计算每个矩阵表达式。以下代码应输出一个 python 函数,该函数使用 Numexpr.

计算所有矩阵表达式

带矩阵的 Numexpr

import numpy as np
import sympy

def lambdify_numexpr(args,expr,expr_name):
    from sympy.printing.lambdarepr import NumExprPrinter as Printer
    printer = Printer({'fully_qualified_modules': False, 'inline': True,'allow_unknown_functions': False})

    s=""
    s+="import numexpr as ne\n"
    s+="from numpy import *\n"
    s+="\n"

    #get arg_names
    arg_names=[]
    arg_names_str=""
    for i in range(len(args)):
        name=[ k for k,v in globals().items() if v is args[i]][0]
        arg_names_str+=name
        arg_names.append(name)

        if i< len(args)-1:
            arg_names_str+=","

    #Write header
    s+="def "+expr_name+"("+arg_names_str+"):\n"

    #unroll array
    for ii in range(len(args)):
        arg=args[ii]
        if arg.is_Matrix:
            for i in range(arg.shape[0]):
                for j in range(arg.shape[1]):
                    s+="    "+ str(arg[i,j])+" = " + arg_names[ii]+"["+str(i)+","+str(j)+"]\n"

    s+="    \n"
    #If the expr is a matrix
    if expr.is_Matrix:
        #write expressions
        for i in range(len(expr)):
            s+="    "+ "res_"+str(i)+" = ne."+printer.doprint(expr[i])+"\n"
            s+="    \n"

        res_counter=0
        #write array
        s+="    return concatenate(("
        for i in range(expr.shape[0]):
            s+="("
            for j in range(expr.shape[1]):
                s+="res_"+str(res_counter)+","
                res_counter+=1
            s+="),"
        s+="))\n"

    #If the expr is not a matrix
    else:
        s+="    "+ "return ne."+printer.doprint(expr)+"\n"
    return s