如何提高循环遍历 numpy 线性求解的计算效率

How can I improve the computational efficiency of looping over numpy linear solve

我有一个方程组需要求解大量变量。代码比下面更复杂,但是问题如下:

我有一个迭代器:

iterator= np.arange(2000)   # array!

以及依赖于该迭代器函数的数组:

A_11 = function_A11(iterator)
A_12 = function_A12(iterator)
A_21 = function_A21(iterator)
A_22 = function_A22(iterator)

B_1 = function_B1(iterator)
B_2 = function_B2(iterator)

X = np.zeros(2, 2000)

for i, (A11, A12, A21, A22, B1, B2) in enumerate(zip(A_11, A_12, A_21, A_22, B_1, B_2):
    A = np.array([[A11, A12],[A21,A22]]) 
    B = np.array([B1, B2])
    X[:,i] = np.linalg.solve(A,B)

该方法有效,但计算量大,我觉得我应该能够优化它,例如通过使用 3D 阵列。有人有什么建议吗?

谢谢!

蒂姆

arr= np.arange(2000)   # array!

A_11 = function_A11(arr)
A_12 = function_A12(arr)
A_21 = function_A21(arr)
A_22 = function_A22(arr)

B_1 = function_B1(arr)
B_2 = function_B2(arr)

合并

AA = np.array([[A_11,A_12],[A_21, A_22]])

我预计 AA.shape 为 (2,2,2000)。

A = AA.transpose(2,0,1)

得到 solve_ivp 可以使用的 (2000,2,2)。

我还没有测试过这个。