没有重复的 Numpy 广播示例

Numpy Broadcasting Example without Repeat

我正在尝试找出一种不使用 np.repeat 创建大维度来执行以下加法运算的好方法。如果使用 np.repeat 并添加是最好的解决方案,请告诉我。

我也对广播在这种情况下的作用感到困惑。本质上我有一个 4d 矩阵,我想在第一个和第二个索引中添加一个 2d 矩阵,同时在索引 0 和索引 3 中执行此操作。

这可以正常工作

a = np.arange(64).reshape((2,4,4,2)).astype(float)
b = np.ones((2,2))
a[:, 0:2, 0:2, : ] += b

这会引发错误。这样做的好方法是什么?

a[:, 0:3, 0:3, :] += np.ones((3,3))

这可行,但不是我想要做的

c = np.arange(144).reshape(3,4,4,3).astype(float)
c[:, 0:3, 0:3, :] += np.ones((3,3))

要对齐要添加的数组的轴,我们需要在末尾插入一个新轴,像这样 -

a[:, 0:3, 0:3, :] += np.ones((3,3))[...,None]

我们来研究一下这里的形状:

In [356]: a[:, 0:3, 0:3, :].shape
Out[356]: (2, 3, 3, 2)

In [357]: np.ones((3,3)).shape
Out[357]: (3, 3)

In [358]: np.ones((3,3))[...,None].shape
Out[358]: (3, 3, 1)


Input1 (a[:, 0:3, 0:3, :])        :     (2, 3, 3, 2) 
Input2 (np.ones((3,3))[...,None]) :        (3, 3, 1)

请记住,广播规则规定单维度(具有 lengths = 1 的维度)将广播以匹配其他 non-singleton 维度的长度。此外,未列出的尺寸实际上默认长度为 1

所以,这是可广播的,现在可以工作了。


第 2 部分:为什么以下内容有效?

c = np.arange(144).reshape(3,4,4,3).astype(float)
c[:, 0:3, 0:3, :] += np.ones((3,3))

再次学习形状-

In [363]: c[:, 0:3, 0:3, :].shape
Out[363]: (3, 3, 3, 3)

In [364]: np.ones((3,3)).shape
Out[364]: (3, 3)

Input1 (c[:, 0:3, 0:3, :])  :     (3, 3, 3, 3) 
Input2 (np.ones((3,3)))     :           (3, 3)

再次按照可广播规则进行,这没问题,所以这里没有错误,但结果不是预期的。

您可以从一开始就包含一个空轴:

a[:, 0:3, 0:3, :] += np.ones((3,3,1))  # 1 broadcasts against any axis

类似的你应该用过:

a[:, 0:2, 0:2, : ] += np.ones((2,2,1))

因为你(可能无意中)对着第三轴和第四轴广播了这些。我想你想让它广播给第二个和第三个,对吧?


此外,您始终可以使用 np.expand_dimsaxis=-1 添加维度:

>>> np.expand_dims(np.ones((2, 2)), axis=-1).shape
(2, 2, 1)

或用 Nonenp.newaxis 切片(它们是等价的!):

>>> np.ones((2, 2))[None, :, :, np.newaxis].shape
(1, 2, 2, 1)

第一个 None 不是正确广播所必需的,但最后一个是!


在这种情况下,重要的是要提到从最后一个维度开始的 numpy 广播。因此,如果您有两个数组,从最后一个开始的每个维度必须具有相同的形状,或者其中一个必须为 1(如果一个为 1,则它会沿该轴传播!)。这就是 a[:, 0:2, 0:2, : ] 起作用的原因:

>>> a[:, 0:2, 0:2, : ].shape
(2, 2, 2, 2)
>>> b.shape
(2, 2)

所以最后一个维度是相等的(2)并且 second-last 是相等的(2)。然而:

>>> np.ones((2,2,1)).shape
(2, 2, 1)

最后一个是21所以np.ones((2,2,1))的最后一个轴是广播的,而第二个和第三个维度是相等的(所有2)所以numpy在那里使用 element-wise 操作。