NumPy tensordot分组计算
NumPy tensordot grouped calculation
假设我有两个数组:
import numpy as np
a=np.array([[1,2],
[3,4]])
b=np.array([[1,2],
[3,4]])
我想按元素乘以数组然后对元素求和,即1*1 + 2*2 + 3*3 + 4*4 = 30
,我可以使用:
np.tensordot(a, b, axes=((-2,-1),(-2,-1)))
>>> array(30)
现在,假设数组 a
和 b
是 2×2×2 数组:
a=np.array([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]])
b=np.array([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]])
并且我想对每个组执行相同的操作,即 [[1,2],[3,4]]
次 [[1,2],[3,4]]
然后对元素求和,[[5,6],[7,8]]
也是如此。结果应为 array([ 30, 174])
,其中 30 = 1*1 + 2*2 + 3*3 + 4*4
和 174 = 5*5 + 6*6 + 7*7 + 8*8
。有没有办法使用 tensordot 来做到这一点?
P.S.
我知道在这种情况下你可以简单地使用 sum 或 einsum:
np.sum(a*b,axis=(-2,-1))
>>> array([ 30, 174])
np.einsum('ijk,ijk->i',a,b)
>>> array([ 30, 174])
但这只是一个简化的例子,我需要使用 tensordot
因为它更快。
感谢您的帮助!!
您可以使用:np.diag(np.tensordot(a, b, axes=((1, 2), (1, 2))))
来获得您想要的结果。但是,在您的情况下,使用 np.tensordot
或矩阵乘法不是一个好主意,因为它们所做的工作比需要的多得多。它们被有效实现的事实并不能平衡它们所做的计算比需要的多得多的事实(这里只有对角线有用)。 np.einsum('ijk,ijk->i',a,b)
计算的内容不会超出您的情况。您可以尝试 optimize=True
甚至 optimize='optimal'
,因为参数 optimize
默认设置为 False
。如果这不够快,您可以尝试使用 NumExpr 以便更有效地计算 np.sum(a*b,axis=(1, 2))
(可能是并行的)。或者,您也可以使用 Numba 或 Cython。两者都支持快速并行循环。
假设我有两个数组:
import numpy as np
a=np.array([[1,2],
[3,4]])
b=np.array([[1,2],
[3,4]])
我想按元素乘以数组然后对元素求和,即1*1 + 2*2 + 3*3 + 4*4 = 30
,我可以使用:
np.tensordot(a, b, axes=((-2,-1),(-2,-1)))
>>> array(30)
现在,假设数组 a
和 b
是 2×2×2 数组:
a=np.array([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]])
b=np.array([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]])
并且我想对每个组执行相同的操作,即 [[1,2],[3,4]]
次 [[1,2],[3,4]]
然后对元素求和,[[5,6],[7,8]]
也是如此。结果应为 array([ 30, 174])
,其中 30 = 1*1 + 2*2 + 3*3 + 4*4
和 174 = 5*5 + 6*6 + 7*7 + 8*8
。有没有办法使用 tensordot 来做到这一点?
P.S.
我知道在这种情况下你可以简单地使用 sum 或 einsum:
np.sum(a*b,axis=(-2,-1))
>>> array([ 30, 174])
np.einsum('ijk,ijk->i',a,b)
>>> array([ 30, 174])
但这只是一个简化的例子,我需要使用 tensordot
因为它更快。
感谢您的帮助!!
您可以使用:np.diag(np.tensordot(a, b, axes=((1, 2), (1, 2))))
来获得您想要的结果。但是,在您的情况下,使用 np.tensordot
或矩阵乘法不是一个好主意,因为它们所做的工作比需要的多得多。它们被有效实现的事实并不能平衡它们所做的计算比需要的多得多的事实(这里只有对角线有用)。 np.einsum('ijk,ijk->i',a,b)
计算的内容不会超出您的情况。您可以尝试 optimize=True
甚至 optimize='optimal'
,因为参数 optimize
默认设置为 False
。如果这不够快,您可以尝试使用 NumExpr 以便更有效地计算 np.sum(a*b,axis=(1, 2))
(可能是并行的)。或者,您也可以使用 Numba 或 Cython。两者都支持快速并行循环。