PyTorch 张量沿任意轴的乘积 à la NumPy 的“tensordot”

Product of PyTorch tensors along arbitrary axes à la NumPy's `tensordot`

NumPy 提供了非常有用的 tensordot 函数。它允许您计算两个 ndarrays 沿任何轴(其大小匹配)的乘积。我很难在 PyTorch 中找到类似的东西。 mm 仅适用于二维数组,matmul 具有一些不受欢迎的广播属性。

我错过了什么吗?我真的想重塑阵列以模仿我想要使用的产品吗?mm

如@McLawrence 所述,目前正在讨论此功能 (issue thread)。

同时,您可以考虑torch.einsum(),例如:

import torch
import numpy as np

a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
c = np.tensordot(a, b, axes=([1,0],[0,1]))
print(c)
# [[ 2640.  2838.] [ 2772.  2982.] [ 2904.  3126.]]

a = torch.from_numpy(a)
b = torch.from_numpy(b)
c = torch.einsum("ijk,jil->kl", (a, b))
print(c)
# tensor([[ 2640.,  2838.], [ 2772.,  2982.], [ 2904.,  3126.]], dtype=torch.float64)

原始答案完全正确,但作为更新,Pytorch now supports tensordot 是原生的。与 numpy 相同的调用签名,但将 axes 更改为 dims.

import torch
import numpy as np

a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
c = np.tensordot(a, b, axes=([1,0],[0,1]))
print(c)
# [[ 2640.  2838.] [ 2772.  2982.] [ 2904.  3126.]]

a = torch.from_numpy(a)
b = torch.from_numpy(b)
c = torch.tensordot(a, b, dims=([1,0],[0,1]))
print(c)
# tensor([[ 2640.,  2838.], [ 2772.,  2982.], [ 2904.,  3126.]], dtype=torch.float64)