NumPy Tensordot 轴=2

NumPy Tensordot axes=2

我知道有很多关于 tensordot 的问题,我已经浏览了 15 页的迷你书答案中的一些,我确定这些答案是人们花了几个小时做出来的,但我还没有找到对什么的解释 axes=2 确实如此。

这让我想到 np.tensordot(b,c,axes=2) == np.sum(b * c),但作为一个数组:

b = np.array([[1,10],[100,1000]])
c = np.array([[2,3],[5,7]])
np.tensordot(b,c,axes=2)
Out: array(7532)

但是后来失败了:

a = np.arange(30).reshape((2,3,5))
np.tensordot(a,a,axes=2)

如果有人能对np.tensordot(x,y,axes=2),而且只有axes=2,提供一个简明扼要的解释,那我很乐意接受。

In [70]: a = np.arange(24).reshape(2,3,4)
In [71]: np.tensordot(a,a,axes=2)
Traceback (most recent call last):
  File "<ipython-input-71-dbe04e46db70>", line 1, in <module>
    np.tensordot(a,a,axes=2)
  File "<__array_function__ internals>", line 5, in tensordot
  File "/usr/local/lib/python3.8/dist-packages/numpy/core/numeric.py", line 1116, in tensordot
    raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum

在我之前的 post 中,我推断 axis=2 转换为 axes=([-2,-1],[0,1])

In [72]: np.tensordot(a,a,axes=([-2,-1],[0,1]))
Traceback (most recent call last):
  File "<ipython-input-72-efdbfe6ff0d3>", line 1, in <module>
    np.tensordot(a,a,axes=([-2,-1],[0,1]))
  File "<__array_function__ internals>", line 5, in tensordot
  File "/usr/local/lib/python3.8/dist-packages/numpy/core/numeric.py", line 1116, in tensordot
    raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum

因此,它试图对第一个 a 的最后两个维度和第二个 a 的前两个维度进行双轴缩减。这个 a 是尺寸不匹配。显然这个 axes 是为 2d 数组设计的,没有考虑 3d 数组。不是3轴缩小。

这些个位数的轴值是一些开发人员认为方便的东西,但这并不意味着它们经过了严格的思考或测试。

元组轴为您提供更多控制权:

In [74]: np.tensordot(a,a,axes=[(0,1,2),(0,1,2)])
Out[74]: array(4324)
In [75]: np.tensordot(a,a,axes=[(0,1),(0,1)])
Out[75]: 
array([[ 880,  940, 1000, 1060],
       [ 940, 1006, 1072, 1138],
       [1000, 1072, 1144, 1216],
       [1060, 1138, 1216, 1294]])