任意形状 NumPy 数组的点积
Dot product for arbitrary shaped NumPy arrays
给定两个 numpy.ndarray
对象,A
和 B
,任意形状,我想计算一个 numpy.ndarray
C
属性 即 C[i] == np.dot(A[i], B[i])
所有 i
。我该怎么做?
例1:A.shape==(2,3,4)
和B.shape==(2,4,5)
,那么我们应该有C.shape==(2,3,5)
。
例子2:A.shape==(2,3,4)
和B.shape==(2,4)
,那么我们应该有C.shape==(2,3)
.
假设你想要 dot
的普通矩阵乘法(不是,比如说,矩阵向量或奇怪的废话 dot
对更高的维度),那么足够新的 NumPy 版本(1.10+)让你做
C = numpy.matmul(A, B)
和足够新的 Python 版本 (3.5+) 让你写成
C = A @ B
假设您的 NumPy 也足够新。
这是一个通用的解决方案,可以使用一些 reshaping
和 np.einsum
来涵盖所有类型的案例/任意形状。 einsum
在这里有所帮助,因为我们需要沿输入数组的第一个轴对齐并沿最后一个轴进行缩减。实现看起来像这样 -
def dotprod_axis0(A,B):
N,nA,nB = A.shape[0], A.shape[-1], B.shape[1]
Ar = A.reshape(N,-1,nA)
Br = B.reshape(N,nB,-1)
return np.squeeze(np.einsum('ijk,ikl->ijl',Ar,Br))
案例
我。 A:2D,B:2D
In [119]: # Inputs
...: A = np.random.randint(0,9,(3,4))
...: B = np.random.randint(0,9,(3,4))
...:
In [120]: for i in range(A.shape[0]):
...: print np.dot(A[i], B[i])
...:
33
86
48
In [121]: dotprod_axis0(A,B)
Out[121]: array([33, 86, 48])
二. A:3D,B:3D
In [122]: # Inputs
...: A = np.random.randint(0,9,(2,3,4))
...: B = np.random.randint(0,9,(2,4,5))
...:
In [123]: for i in range(A.shape[0]):
...: print np.dot(A[i], B[i])
...:
[[ 74 70 53 118 43]
[ 47 43 29 95 30]
[ 41 37 26 23 15]]
[[ 50 86 33 35 82]
[ 78 126 40 124 140]
[ 67 88 35 47 83]]
In [124]: dotprod_axis0(A,B)
Out[124]:
array([[[ 74, 70, 53, 118, 43],
[ 47, 43, 29, 95, 30],
[ 41, 37, 26, 23, 15]],
[[ 50, 86, 33, 35, 82],
[ 78, 126, 40, 124, 140],
[ 67, 88, 35, 47, 83]]])
三. A:3D,B:2D
In [125]: # Inputs
...: A = np.random.randint(0,9,(2,3,4))
...: B = np.random.randint(0,9,(2,4))
...:
In [126]: for i in range(A.shape[0]):
...: print np.dot(A[i], B[i])
...:
[ 87 105 53]
[152 135 120]
In [127]: dotprod_axis0(A,B)
Out[127]:
array([[ 87, 105, 53],
[152, 135, 120]])
四. A:2D,B:3D
In [128]: # Inputs
...: A = np.random.randint(0,9,(2,4))
...: B = np.random.randint(0,9,(2,4,5))
...:
In [129]: for i in range(A.shape[0]):
...: print np.dot(A[i], B[i])
...:
[76 93 31 75 16]
[ 33 98 49 117 111]
In [130]: dotprod_axis0(A,B)
Out[130]:
array([[ 76, 93, 31, 75, 16],
[ 33, 98, 49, 117, 111]])
给定两个 numpy.ndarray
对象,A
和 B
,任意形状,我想计算一个 numpy.ndarray
C
属性 即 C[i] == np.dot(A[i], B[i])
所有 i
。我该怎么做?
例1:A.shape==(2,3,4)
和B.shape==(2,4,5)
,那么我们应该有C.shape==(2,3,5)
。
例子2:A.shape==(2,3,4)
和B.shape==(2,4)
,那么我们应该有C.shape==(2,3)
.
假设你想要 dot
的普通矩阵乘法(不是,比如说,矩阵向量或奇怪的废话 dot
对更高的维度),那么足够新的 NumPy 版本(1.10+)让你做
C = numpy.matmul(A, B)
和足够新的 Python 版本 (3.5+) 让你写成
C = A @ B
假设您的 NumPy 也足够新。
这是一个通用的解决方案,可以使用一些 reshaping
和 np.einsum
来涵盖所有类型的案例/任意形状。 einsum
在这里有所帮助,因为我们需要沿输入数组的第一个轴对齐并沿最后一个轴进行缩减。实现看起来像这样 -
def dotprod_axis0(A,B):
N,nA,nB = A.shape[0], A.shape[-1], B.shape[1]
Ar = A.reshape(N,-1,nA)
Br = B.reshape(N,nB,-1)
return np.squeeze(np.einsum('ijk,ikl->ijl',Ar,Br))
案例
我。 A:2D,B:2D
In [119]: # Inputs
...: A = np.random.randint(0,9,(3,4))
...: B = np.random.randint(0,9,(3,4))
...:
In [120]: for i in range(A.shape[0]):
...: print np.dot(A[i], B[i])
...:
33
86
48
In [121]: dotprod_axis0(A,B)
Out[121]: array([33, 86, 48])
二. A:3D,B:3D
In [122]: # Inputs
...: A = np.random.randint(0,9,(2,3,4))
...: B = np.random.randint(0,9,(2,4,5))
...:
In [123]: for i in range(A.shape[0]):
...: print np.dot(A[i], B[i])
...:
[[ 74 70 53 118 43]
[ 47 43 29 95 30]
[ 41 37 26 23 15]]
[[ 50 86 33 35 82]
[ 78 126 40 124 140]
[ 67 88 35 47 83]]
In [124]: dotprod_axis0(A,B)
Out[124]:
array([[[ 74, 70, 53, 118, 43],
[ 47, 43, 29, 95, 30],
[ 41, 37, 26, 23, 15]],
[[ 50, 86, 33, 35, 82],
[ 78, 126, 40, 124, 140],
[ 67, 88, 35, 47, 83]]])
三. A:3D,B:2D
In [125]: # Inputs
...: A = np.random.randint(0,9,(2,3,4))
...: B = np.random.randint(0,9,(2,4))
...:
In [126]: for i in range(A.shape[0]):
...: print np.dot(A[i], B[i])
...:
[ 87 105 53]
[152 135 120]
In [127]: dotprod_axis0(A,B)
Out[127]:
array([[ 87, 105, 53],
[152, 135, 120]])
四. A:2D,B:3D
In [128]: # Inputs
...: A = np.random.randint(0,9,(2,4))
...: B = np.random.randint(0,9,(2,4,5))
...:
In [129]: for i in range(A.shape[0]):
...: print np.dot(A[i], B[i])
...:
[76 93 31 75 16]
[ 33 98 49 117 111]
In [130]: dotprod_axis0(A,B)
Out[130]:
array([[ 76, 93, 31, 75, 16],
[ 33, 98, 49, 117, 111]])