NumPy Tensordot 轴=2
NumPy Tensordot axes=2
我知道有很多关于 tensordot 的问题,我已经浏览了 15 页的迷你书答案中的一些,我确定这些答案是人们花了几个小时做出来的,但我还没有找到对什么的解释 axes=2
确实如此。
这让我想到 np.tensordot(b,c,axes=2) == np.sum(b * c)
,但作为一个数组:
b = np.array([[1,10],[100,1000]])
c = np.array([[2,3],[5,7]])
np.tensordot(b,c,axes=2)
Out: array(7532)
但是后来失败了:
a = np.arange(30).reshape((2,3,5))
np.tensordot(a,a,axes=2)
如果有人能对np.tensordot(x,y,axes=2)
,而且只有axes=2
,提供一个简明扼要的解释,那我很乐意接受。
In [70]: a = np.arange(24).reshape(2,3,4)
In [71]: np.tensordot(a,a,axes=2)
Traceback (most recent call last):
File "<ipython-input-71-dbe04e46db70>", line 1, in <module>
np.tensordot(a,a,axes=2)
File "<__array_function__ internals>", line 5, in tensordot
File "/usr/local/lib/python3.8/dist-packages/numpy/core/numeric.py", line 1116, in tensordot
raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum
在我之前的 post 中,我推断 axis=2
转换为 axes=([-2,-1],[0,1])
In [72]: np.tensordot(a,a,axes=([-2,-1],[0,1]))
Traceback (most recent call last):
File "<ipython-input-72-efdbfe6ff0d3>", line 1, in <module>
np.tensordot(a,a,axes=([-2,-1],[0,1]))
File "<__array_function__ internals>", line 5, in tensordot
File "/usr/local/lib/python3.8/dist-packages/numpy/core/numeric.py", line 1116, in tensordot
raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum
因此,它试图对第一个 a
的最后两个维度和第二个 a
的前两个维度进行双轴缩减。这个 a
是尺寸不匹配。显然这个 axes
是为 2d 数组设计的,没有考虑 3d 数组。不是3轴缩小。
这些个位数的轴值是一些开发人员认为方便的东西,但这并不意味着它们经过了严格的思考或测试。
元组轴为您提供更多控制权:
In [74]: np.tensordot(a,a,axes=[(0,1,2),(0,1,2)])
Out[74]: array(4324)
In [75]: np.tensordot(a,a,axes=[(0,1),(0,1)])
Out[75]:
array([[ 880, 940, 1000, 1060],
[ 940, 1006, 1072, 1138],
[1000, 1072, 1144, 1216],
[1060, 1138, 1216, 1294]])
我知道有很多关于 tensordot 的问题,我已经浏览了 15 页的迷你书答案中的一些,我确定这些答案是人们花了几个小时做出来的,但我还没有找到对什么的解释 axes=2
确实如此。
这让我想到 np.tensordot(b,c,axes=2) == np.sum(b * c)
,但作为一个数组:
b = np.array([[1,10],[100,1000]])
c = np.array([[2,3],[5,7]])
np.tensordot(b,c,axes=2)
Out: array(7532)
但是后来失败了:
a = np.arange(30).reshape((2,3,5))
np.tensordot(a,a,axes=2)
如果有人能对np.tensordot(x,y,axes=2)
,而且只有axes=2
,提供一个简明扼要的解释,那我很乐意接受。
In [70]: a = np.arange(24).reshape(2,3,4)
In [71]: np.tensordot(a,a,axes=2)
Traceback (most recent call last):
File "<ipython-input-71-dbe04e46db70>", line 1, in <module>
np.tensordot(a,a,axes=2)
File "<__array_function__ internals>", line 5, in tensordot
File "/usr/local/lib/python3.8/dist-packages/numpy/core/numeric.py", line 1116, in tensordot
raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum
在我之前的 post 中,我推断 axis=2
转换为 axes=([-2,-1],[0,1])
In [72]: np.tensordot(a,a,axes=([-2,-1],[0,1]))
Traceback (most recent call last):
File "<ipython-input-72-efdbfe6ff0d3>", line 1, in <module>
np.tensordot(a,a,axes=([-2,-1],[0,1]))
File "<__array_function__ internals>", line 5, in tensordot
File "/usr/local/lib/python3.8/dist-packages/numpy/core/numeric.py", line 1116, in tensordot
raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum
因此,它试图对第一个 a
的最后两个维度和第二个 a
的前两个维度进行双轴缩减。这个 a
是尺寸不匹配。显然这个 axes
是为 2d 数组设计的,没有考虑 3d 数组。不是3轴缩小。
这些个位数的轴值是一些开发人员认为方便的东西,但这并不意味着它们经过了严格的思考或测试。
元组轴为您提供更多控制权:
In [74]: np.tensordot(a,a,axes=[(0,1,2),(0,1,2)])
Out[74]: array(4324)
In [75]: np.tensordot(a,a,axes=[(0,1),(0,1)])
Out[75]:
array([[ 880, 940, 1000, 1060],
[ 940, 1006, 1072, 1138],
[1000, 1072, 1144, 1216],
[1060, 1138, 1216, 1294]])