python 中的 Einsum 用于复杂循环

Einsum in python for a complex loop

我在 python 中有复杂的循环,我试图对其进行“矢量化”以缩短计算时间。我发现函数 np.einsum 允许它,我设法使用它,但我陷入了另一个循环。

在下面的代码中,我将设法“einsumize”(s1) 的循环和另一个我没有的循环放在一起。

import numpy as np

Q = 6
P = 24
N = 40

bQ = np.arange(Q)
bP = np.arange(P)
uN = np.arange(N)
t1 = np.arange(P*Q*N).reshape([Q,P,N])
t2 = np.arange(Q*N*Q*N).reshape([Q,N,Q,N])

s1_ = np.einsum('p,q,n,qpn',bP, bQ, uN, t1)
s1 = 0
for p in range(P):
    for q in range(Q):
        for n in range(N):
            s1 += bP[p] * bQ[q] * uN[n] * t1[q,p,n]
print(s1)
print(s1_)
print()

s2_ = np.einsum('p,q,n,m,pnqm', bQ, bQ, uN, uN, t2)
s2 = 0
for p in range(Q):
    for q in range(Q):
        for n in range(N):
            for m in range(N):
                s2 += bQ[q] * bQ[q] * uN[n] * uN[m] * t2[p,n,q,m]
print(s2)
print(s2_)

前面代码的结果是

13475451600
13475451600

6125547636000
5707354770000

计算 s1 的数学公式是:s1 = \sum_p\sum_q\sum_n bP[p] * bQ[q] * uN[n] * t1[q,p,n]。计算 s2 的是 s2 = \sum_q\sum_q'\sum_n\sum_n' bQ[q] bQ[q'] uN[n] uN[n'] t2[q,n,q',n'].

对于 triple 循环,如果我很好地理解 einsum 是如何工作的,我会告诉将要相乘的张量的不同索引,并告诉没有输出索引告诉所有将被总结。但它似乎不适用于 quadruple 循环。

编辑: 我看到一个答案(似乎已被删除),告诉我这只是四重循环中索引的错误......我应该看到它:/

我发现你的代码中有错别字,你没有在 4 阶张量之外使用变量 p

尝试改变

                s2 += bQ[q] * bQ[q] * uN[n] * uN[m] * t2[p,n,q,m]

为了

                s2 += bQ[p] * bQ[q] * uN[n] * uN[m] * t2[p,n,q,m]