用 numpy 推广矩阵乘法

Generalize matrix multiplication with numpy

我有以下代码片段:

import numpy as np
a = np.arange(18).reshape(2,3,3)
b = np.arange(6).reshape(2,3)
c = np.zeros((2,3))
c[0] = a[0] @ b[0]
c[1] = a[1] @ b[1]

我如何概括任何 a(n,3,3)b(n,3)c(n,3)

我认为 einsum 是正确的方法,但我不太清楚正确的语法...

您可以广播或使用 einsum(更好的 einsum):

import numpy as np
a = np.arange(18).reshape(2,3,3)
b = np.arange(6).reshape(2,3)
c = np.zeros((2,3))
c[0] = a[0] @ b[0]
c[1] = a[1] @ b[1]

res_broad = (a*b[:,None,:]).sum(2)

res_ein = np.einsum('ijk,ik->ij',a,b)

print(f"broadcast works: {np.allclose(c,res_broad)}")
print(f"einsum works: {np.allclose(c,res_broad)}")