numpy 中是否有简化形式的点积?
Is there a reduced form of the dot product in numpy?
我正在尝试找到一个 numpy 运算,该运算给出 3d 数组索引 i 处的 2d 数组的所有向量与 2d 数组索引 i 处的向量之间的标量积。让我举个例子来解释我的想法:
x = np.array([[[1,2,3],
[2,3,4]],
[[11,12,13],
[12,13,14]]])
y = np.array([[1,1,1],
[2,2,2]])
np.?operation?(x,y.T)
output:
[[[1 *1 + 1 *2 + 1 *3],
[1 *2 + 1 *3 + 1 *4]],
[[2 *11 + 2 *12 + 2 *13],
[2 *12 + 2* 13 + 2 *14]]]
= [[[6],
[9]],
[[72],
[78]]]
如您所见,我基本上是在寻找简化的点积运算。 x 和 y 的点积将产生以下结果:
np.dot(x, y.T)
output:
[[[ 6 12]
[ 9 18]]
[[36 72]
[39 78]]]
或者有没有办法从点积结果中提取我需要的结果?
我也试过 np.tensordot(x,y,axis) 但我无法确定我应该为 -axis- 放置哪些元组。
我也遇到过 np.einsum() 操作,但无法理解这如何帮助我解决问题。
它应该可以用 np.einsum
或 np.matmul/@
来实现,它们在前导维度上有一个“批处理”操作。但是整理尺寸并获得 (2,2,1) 形状有点棘手。
你的 np.dot(x, y.T)
给出了你想要的数字,但你必须在 2 上提取一种对角线,同时保留一个维度。
这是执行此操作的一种方法 - 它不是最快或最简洁的,但应该可以帮助我了解尺寸。
In [432]: y[:,None,:]
Out[432]:
array([[[1, 1, 1]],
[[2, 2, 2]]])
In [433]: y[:,None,:].repeat(2,1)
Out[433]:
array([[[1, 1, 1],
[1, 1, 1]],
[[2, 2, 2],
[2, 2, 2]]])
In [435]: x*y[:,None,:].repeat(2,1)
Out[435]:
array([[[ 1, 2, 3],
[ 2, 3, 4]],
[[22, 24, 26],
[24, 26, 28]]])
In [436]: (x*y[:,None,:].repeat(2,1)).sum(axis=-1, keepdims=True)
Out[436]:
array([[[ 6],
[ 9]],
[[72],
[78]]])
我们不需要重复,broadcasting
将取代它:
(x*y[:,None,:]).sum(axis=-1, keepdims=True)
此 einsum
与 dot/@
的作用相同:
In [441]: np.einsum('ijk,lk->ijl',x,y)
Out[441]:
array([[[ 6, 12],
[ 9, 18]],
[[36, 72],
[39, 78]]])
稍微更改索引以获得“对角线”(i
所有术语)
In [442]: np.einsum('ijk,ik->ij',x,y)
Out[442]:
array([[ 6, 9],
[72, 78]])
并添加尾随维度:
In [443]: np.einsum('ijk,ik->ij',x,y)[:,:,None]
Out[443]:
array([[[ 6],
[ 9]],
[[72],
[78]]])
现在我有了 einsum
,我可以想象 matmul/@
维度。我需要将两者的第一个维度视为 'batch',并向 y
添加一个新的尾随维度,使其成为 (2,3,1)。 (2,2,3) with (2,3,1) => (2,2,1) with the sum-of-products on the 3.
In [445]: x@y[:,:,None]
Out[445]:
array([[[ 6],
[ 9]],
[[72],
[78]]])
如果x
和y
是(4,2,3)和(4,3)形,这个维度匹配会更明显
In [446]: X=x.repeat(2,0)
In [447]: Y=y.repeat(2,0)
In [448]: X.shape
Out[448]: (4, 2, 3)
In [449]: Y.shape
Out[449]: (4, 3)
In [450]: X@Y[:,:,None] # (4,2,1)
Out[450]:
array([[[ 6],
[ 9]],
[[ 6],
[ 9]],
[[72],
[78]],
[[72],
[78]]])
有了这些形状,更明显的是 4
是批次,3
是乘积之和。
我正在尝试找到一个 numpy 运算,该运算给出 3d 数组索引 i 处的 2d 数组的所有向量与 2d 数组索引 i 处的向量之间的标量积。让我举个例子来解释我的想法:
x = np.array([[[1,2,3],
[2,3,4]],
[[11,12,13],
[12,13,14]]])
y = np.array([[1,1,1],
[2,2,2]])
np.?operation?(x,y.T)
output:
[[[1 *1 + 1 *2 + 1 *3],
[1 *2 + 1 *3 + 1 *4]],
[[2 *11 + 2 *12 + 2 *13],
[2 *12 + 2* 13 + 2 *14]]]
= [[[6],
[9]],
[[72],
[78]]]
如您所见,我基本上是在寻找简化的点积运算。 x 和 y 的点积将产生以下结果:
np.dot(x, y.T)
output:
[[[ 6 12]
[ 9 18]]
[[36 72]
[39 78]]]
或者有没有办法从点积结果中提取我需要的结果?
我也试过 np.tensordot(x,y,axis) 但我无法确定我应该为 -axis- 放置哪些元组。 我也遇到过 np.einsum() 操作,但无法理解这如何帮助我解决问题。
它应该可以用 np.einsum
或 np.matmul/@
来实现,它们在前导维度上有一个“批处理”操作。但是整理尺寸并获得 (2,2,1) 形状有点棘手。
你的 np.dot(x, y.T)
给出了你想要的数字,但你必须在 2 上提取一种对角线,同时保留一个维度。
这是执行此操作的一种方法 - 它不是最快或最简洁的,但应该可以帮助我了解尺寸。
In [432]: y[:,None,:]
Out[432]:
array([[[1, 1, 1]],
[[2, 2, 2]]])
In [433]: y[:,None,:].repeat(2,1)
Out[433]:
array([[[1, 1, 1],
[1, 1, 1]],
[[2, 2, 2],
[2, 2, 2]]])
In [435]: x*y[:,None,:].repeat(2,1)
Out[435]:
array([[[ 1, 2, 3],
[ 2, 3, 4]],
[[22, 24, 26],
[24, 26, 28]]])
In [436]: (x*y[:,None,:].repeat(2,1)).sum(axis=-1, keepdims=True)
Out[436]:
array([[[ 6],
[ 9]],
[[72],
[78]]])
我们不需要重复,broadcasting
将取代它:
(x*y[:,None,:]).sum(axis=-1, keepdims=True)
此 einsum
与 dot/@
的作用相同:
In [441]: np.einsum('ijk,lk->ijl',x,y)
Out[441]:
array([[[ 6, 12],
[ 9, 18]],
[[36, 72],
[39, 78]]])
稍微更改索引以获得“对角线”(i
所有术语)
In [442]: np.einsum('ijk,ik->ij',x,y)
Out[442]:
array([[ 6, 9],
[72, 78]])
并添加尾随维度:
In [443]: np.einsum('ijk,ik->ij',x,y)[:,:,None]
Out[443]:
array([[[ 6],
[ 9]],
[[72],
[78]]])
现在我有了 einsum
,我可以想象 matmul/@
维度。我需要将两者的第一个维度视为 'batch',并向 y
添加一个新的尾随维度,使其成为 (2,3,1)。 (2,2,3) with (2,3,1) => (2,2,1) with the sum-of-products on the 3.
In [445]: x@y[:,:,None]
Out[445]:
array([[[ 6],
[ 9]],
[[72],
[78]]])
如果x
和y
是(4,2,3)和(4,3)形,这个维度匹配会更明显
In [446]: X=x.repeat(2,0)
In [447]: Y=y.repeat(2,0)
In [448]: X.shape
Out[448]: (4, 2, 3)
In [449]: Y.shape
Out[449]: (4, 3)
In [450]: X@Y[:,:,None] # (4,2,1)
Out[450]:
array([[[ 6],
[ 9]],
[[ 6],
[ 9]],
[[72],
[78]],
[[72],
[78]]])
有了这些形状,更明显的是 4
是批次,3
是乘积之和。