numpy 中是否有简化形式的点积?

Is there a reduced form of the dot product in numpy?

我正在尝试找到一个 numpy 运算,该运算给出 3d 数组索引 i 处的 2d 数组的所有向量与 2d 数组索引 i 处的向量之间的标量积。让我举个例子来解释我的想法:

x = np.array([[[1,2,3],
              [2,3,4]],
             
             [[11,12,13],
              [12,13,14]]])

y = np.array([[1,1,1],
              [2,2,2]])



np.?operation?(x,y.T)

output:
[[[1 *1 + 1 *2 + 1 *3], 
  [1 *2 + 1 *3 + 1 *4]],

 [[2 *11 + 2 *12 + 2 *13], 
  [2 *12 + 2* 13 + 2 *14]]]

= [[[6], 
    [9]],

   [[72], 
    [78]]]

如您所见,我基本上是在寻找简化的点积运算。 x 和 y 的点积将产生以下结果:

np.dot(x, y.T)

output:
[[[ 6 12]
  [ 9 18]]

 [[36 72]
  [39 78]]]

或者有没有办法从点积结果中提取我需要的结果?

我也试过 np.tensordot(x,y,axis) 但我无法确定我应该为 -axis- 放置哪些元组。 我也遇到过 np.einsum() 操作,但无法理解这如何帮助我解决问题。

它应该可以用 np.einsumnp.matmul/@ 来实现,它们在前导维度上有一个“批处理”操作。但是整理尺寸并获得 (2,2,1) 形状有点棘手。

你的 np.dot(x, y.T) 给出了你想要的数字,但你必须在 2 上提取一种对角线,同时保留一个维度。

这是执行此操作的一种方法 - 它不是最快或最简洁的,但应该可以帮助我了解尺寸。

In [432]: y[:,None,:]
Out[432]: 
array([[[1, 1, 1]],

       [[2, 2, 2]]])
In [433]: y[:,None,:].repeat(2,1)
Out[433]: 
array([[[1, 1, 1],
        [1, 1, 1]],

       [[2, 2, 2],
        [2, 2, 2]]])
In [435]: x*y[:,None,:].repeat(2,1)
Out[435]: 
array([[[ 1,  2,  3],
        [ 2,  3,  4]],

       [[22, 24, 26],
        [24, 26, 28]]])
In [436]: (x*y[:,None,:].repeat(2,1)).sum(axis=-1, keepdims=True)
Out[436]: 
array([[[ 6],
        [ 9]],

       [[72],
        [78]]])

我们不需要重复,broadcasting 将取代它:

(x*y[:,None,:]).sum(axis=-1, keepdims=True)

einsumdot/@ 的作用相同:

In [441]: np.einsum('ijk,lk->ijl',x,y)
Out[441]: 
array([[[ 6, 12],
        [ 9, 18]],

       [[36, 72],
        [39, 78]]])

稍微更改索引以获得“对角线”(i 所有术语)

In [442]: np.einsum('ijk,ik->ij',x,y)
Out[442]: 
array([[ 6,  9],
       [72, 78]])

并添加尾随维度:

In [443]: np.einsum('ijk,ik->ij',x,y)[:,:,None]
Out[443]: 
array([[[ 6],
        [ 9]],

       [[72],
        [78]]])

现在我有了 einsum,我可以想象 matmul/@ 维度。我需要将两者的第一个维度视为 'batch',并向 y 添加一个新的尾随维度,使其成为 (2,3,1)。 (2,2,3) with (2,3,1) => (2,2,1) with the sum-of-products on the 3.

In [445]: x@y[:,:,None]
Out[445]: 
array([[[ 6],
        [ 9]],

       [[72],
        [78]]])

如果xy是(4,2,3)和(4,3)形,这个维度匹配会更明显

In [446]: X=x.repeat(2,0)
In [447]: Y=y.repeat(2,0)
In [448]: X.shape
Out[448]: (4, 2, 3)
In [449]: Y.shape
Out[449]: (4, 3)
In [450]: X@Y[:,:,None]    # (4,2,1)
Out[450]: 
array([[[ 6],
        [ 9]],

       [[ 6],
        [ 9]],

       [[72],
        [78]],

       [[72],
        [78]]])

有了这些形状,更明显的是 4 是批次,3 是乘积之和。