numpy einsum:嵌套点积

numpy einsum: nested dot products

我有两个 n-by-k-by-3 数组 ab,例如,

import numpy as np

a = np.array([
    [
        [1, 2, 3],
        [3, 4, 5]
    ],
    [
        [4, 2, 4],
        [1, 4, 5]
    ]
    ])
b = np.array([
    [
        [3, 1, 5],
        [0, 2, 3]
    ],
    [
        [2, 4, 5],
        [1, 2, 4]
    ]
    ])

它喜欢计算所有 "triplets" 对的点积,即

np.sum(a*b, axis=2)

更好的方法可能是 einsum,但我似乎无法弄清楚索引。

这里有什么提示吗?

你在这两个 3D 输入数组上丢失了第三个轴,同时保持前两个轴对齐。因此,使用 np.einsum 时,前两个字符串相同,第三个字符串也相同,但在输出字符串符号中会被跳过,这表明我们正在沿该轴减少两个输入。因此,解决方案是 -

np.einsum('ijk,ijk->ij',a,b)