numpy.einsum 表达式的含义/等价物

Meaning / equivalent of numpy.einsum expression

我正在拼命寻找 python 内置等价于以下 numpy.einsum 表达式:

>>> a = np.array((((1, 2), (3, 4)), ((5, 6), (7, 8))))
>>> a
array([[[1, 2],
        [3, 4]],

       [[5, 6],
        [7, 8]]])

>>> b = np.array((((9, 10), (11, 12)), ((13, 14), (15, 16))))
>>> b
array([[[ 9, 10],
        [11, 12]],

       [[13, 14],
        [15, 16]]])

>>> np.einsum("abc,abd->dc", a, b)
array([[212, 260],
       [228, 280]])

正如@AlexRiley 评论的那样,直接翻译是这样的:

(a[...,None,:]*b[...,None]).sum((0,1))

让我们解析规范字符串 'abc,abd->dc' 并将术语重命名为 x 和 y,这样它们就不会与索引冲突:

这是读作结果dc = ∑ab xabc yabd

如您所见,索引是从规范字符串中逐字获取的。结果规范中未出现的索引被求和。就是这样。

旁注:我们可以做得更好:合并前两个轴,表达式可以作为矩阵乘积读取,numpy 使用高度优化的代码路径:

b.reshape(-1,b.shape[-1]).T@a.reshape(-1,a.shape[-1])

这比直接翻译快两倍多,也比原来的快一点einsum