混合形状数组列的点积

Dot products of columns of mixed shaped arrays

我正在尝试获取 nx2x3 数组和 nx3 数组中每个元素的点积(n 的值始终在两者之间共享)。

例如:

import numpy as np

a = np.arange(12).reshape(4,3)
b = np.arange(24).reshape(4,2,3)

我试图获取的数组将包含这些:

print(np.dot(b[0],a[0]))
print(np.dot(b[1],a[1]))
print(np.dot(b[2],a[2]))
print(np.dot(b[3],a[3]))

我确定有一种方法可以使用 einsumtensordot,但我无法让它工作。

您可以这样使用 einsum

>>> np.einsum('ij,ikj->ik', a, b)
array([[  5,  14],
       [ 86, 122],
       [275, 338],
       [572, 662]])

这里发生的所有事情是 a 的轴 0 与 b 的轴 0 相乘,a 的轴 1 与 b 的轴 2 相乘.沿后一个轴的值​​相加并返回二维数组。

(tensordot 并不能很好地解决这个问题,因为我们需要沿着 two 轴进行乘法,而只需要沿着 one。这些操作只与tensordot成对出现。)