PyTorch 张量沿任意轴的乘积 à la NumPy 的“tensordot”
Product of PyTorch tensors along arbitrary axes à la NumPy's `tensordot`
NumPy 提供了非常有用的 tensordot
函数。它允许您计算两个 ndarrays
沿任何轴(其大小匹配)的乘积。我很难在 PyTorch 中找到类似的东西。 mm
仅适用于二维数组,matmul
具有一些不受欢迎的广播属性。
我错过了什么吗?我真的想重塑阵列以模仿我想要使用的产品吗?mm
?
如@McLawrence 所述,目前正在讨论此功能 (issue thread)。
同时,您可以考虑torch.einsum()
,例如:
import torch
import numpy as np
a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
c = np.tensordot(a, b, axes=([1,0],[0,1]))
print(c)
# [[ 2640. 2838.] [ 2772. 2982.] [ 2904. 3126.]]
a = torch.from_numpy(a)
b = torch.from_numpy(b)
c = torch.einsum("ijk,jil->kl", (a, b))
print(c)
# tensor([[ 2640., 2838.], [ 2772., 2982.], [ 2904., 3126.]], dtype=torch.float64)
原始答案完全正确,但作为更新,Pytorch now supports tensordot
是原生的。与 numpy 相同的调用签名,但将 axes
更改为 dims
.
import torch
import numpy as np
a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
c = np.tensordot(a, b, axes=([1,0],[0,1]))
print(c)
# [[ 2640. 2838.] [ 2772. 2982.] [ 2904. 3126.]]
a = torch.from_numpy(a)
b = torch.from_numpy(b)
c = torch.tensordot(a, b, dims=([1,0],[0,1]))
print(c)
# tensor([[ 2640., 2838.], [ 2772., 2982.], [ 2904., 3126.]], dtype=torch.float64)
NumPy 提供了非常有用的 tensordot
函数。它允许您计算两个 ndarrays
沿任何轴(其大小匹配)的乘积。我很难在 PyTorch 中找到类似的东西。 mm
仅适用于二维数组,matmul
具有一些不受欢迎的广播属性。
我错过了什么吗?我真的想重塑阵列以模仿我想要使用的产品吗?mm
?
如@McLawrence 所述,目前正在讨论此功能 (issue thread)。
同时,您可以考虑torch.einsum()
,例如:
import torch
import numpy as np
a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
c = np.tensordot(a, b, axes=([1,0],[0,1]))
print(c)
# [[ 2640. 2838.] [ 2772. 2982.] [ 2904. 3126.]]
a = torch.from_numpy(a)
b = torch.from_numpy(b)
c = torch.einsum("ijk,jil->kl", (a, b))
print(c)
# tensor([[ 2640., 2838.], [ 2772., 2982.], [ 2904., 3126.]], dtype=torch.float64)
原始答案完全正确,但作为更新,Pytorch now supports tensordot
是原生的。与 numpy 相同的调用签名,但将 axes
更改为 dims
.
import torch
import numpy as np
a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
c = np.tensordot(a, b, axes=([1,0],[0,1]))
print(c)
# [[ 2640. 2838.] [ 2772. 2982.] [ 2904. 3126.]]
a = torch.from_numpy(a)
b = torch.from_numpy(b)
c = torch.tensordot(a, b, dims=([1,0],[0,1]))
print(c)
# tensor([[ 2640., 2838.], [ 2772., 2982.], [ 2904., 3126.]], dtype=torch.float64)