通过 NumPy 高级索引应用操作时出现问题

Problem applying an operation by NumPy advanced indexing

我有两个数组 c = (2, 10, 2, 3)p = (2, 5, 3)。我想选择 c 的 3d 维度中的第一个向量,并从 p 的相应行中的每个向量中减去它。通过这个例子会更好理解,我们可以在第一行用c[0, :, None, 0] - p[0]c[0, :, 0][:, None] - p[0]做这个操作,得到形状为(10, 5, 3)的结果。但是我很困惑如何通过高级索引一次对所有行应用这个操作,结果将是形状 (2, 10, 5, 3).

p = np.array([[[0., -0.05,  0.], [0., -0.05,  0.], [0.,  0.,  0.], [-0.05, -0.1,  0.05], [0.1,  0.2,  0.1]],
              [[0., -0.05,  0.], [0., -0.05,  0.], [0.,  0.,  0.], [-0.05, -0.1,  0.05], [0.1,  0.2,  0.1]]])

c = np.array([[[[ 0.05,  0. ,  0.05], [ 0.05, -0.1,  0.05]], [[ 0.05, -0.1,  0.05], [-0.05, -0.1,  0.05]],
               [[-0.05, -0.1,  0.05], [-0.05,  0. ,  0.05]], [[ 0.05,  0. , -0.05], [ 0.05, -0.1, -0.05]],
               [[ 0.05, -0.1, -0.05], [-0.05, -0.1, -0.05]], [[-0.05, -0.1, -0.05], [-0.05,  0. , -0.05]],
               [[ 0.05,  0. ,  0.05], [ 0.05,  0. , -0.05]], [[ 0.05, -0.1,  0.05], [ 0.05, -0.1, -0.05]],
               [[-0.05, -0.1,  0.05], [-0.05, -0.1, -0.05]], [[-0.05,  0. ,  0.05], [-0.05,  0. , -0.05]]],
              [[[ 0.05,  0.1,  0.05], [ 0.05,  0. ,  0.05]], [[ 0.05,  0. ,  0.05], [-0.05,  0. ,  0.05]],
               [[-0.05,  0. ,  0.05], [-0.05,  0.1,  0.05]], [[ 0.05,  0.1, -0.05], [ 0.05,  0. , -0.05]],
               [[ 0.05,  0. , -0.05], [-0.05,  0. , -0.05]], [[-0.05,  0. , -0.05], [-0.05,  0.1, -0.05]],
               [[ 0.05,  0.1,  0.05], [ 0.05,  0.1, -0.05]], [[ 0.05,  0. ,  0.05], [ 0.05,  0. , -0.05]],
               [[-0.05,  0. ,  0.05], [-0.05,  0. , -0.05]], [[-0.05,  0.1,  0.05], [-0.05,  0.1, -0.05]]]])

补充建议:
这不是主要问题,但对我有帮助,如果有任何替代方法可以解决此问题,我将不胜感激:
我已经将 p 从形状 (5, 3) 广播到 (2, 5, 3) 以对应于 c 的第一个维度。有没有更好的方法,没有这种广播(可能只是通过高级索引代替)来处理这个问题?

让我们从这里开始:

I have broadcasted p from shape (5, 3) to (2, 5, 3) to be corresponded with 1st dimension of c. Is there a better way, without this broadcasting (may be just by advanced indexing instead), to handle this problem?

所以,如果我没理解错的话,p 实际上应该是 (5, 3),你手动将它加倍来制作一个 (2, 5, 3) 数组。正如您将看到的,您不必这样做,实际上坚持使用 a (5, 3) 数组 将使整个事情变得容易得多。

现在当你这样做时:

c[0, :, None, 0]

得到的是一个(10, 1, 3)数组。你在这里走在正确的轨道上,因为 1 的第二个维度是广播所需要的,但是同样,你可以使用

来做更简单的事情
c[:, :, None, 0]

这会给你 a (2, 10, 1, 3) 数组。为了利用 numpy's broadcasting rules:

,您真的非常需要这些东西

When operating on two arrays, NumPy compares their shapes element-wise. It starts with the trailing (i.e. rightmost) dimensions and works its way left. Two dimensions are compatible when: 1. they are equal, or 2. one of them is 1

所以,如果我们有 p.shape == (5, 3) 和 c 转换使得 c.shape == (2, 10, 1, 3),匹配将从左:

  • 最后一个维度将匹配,因为它们都是 3(规则 1)
  • c的倒数第二个为1,可以匹配p的第一个(规则2)
  • 较短的 (p) 将沿着剩余的 c 维度进行匹配

示意图:

 (2, 10, 1, 3)
  |  |   |  |
  |  |   V  V
  |  |  (5, 3)
  |  |   |  |
  V  V   V  V
 (2, 10, 5, 3) 

解释很长,但实际上应用程序将非常简单。那么,让我们来试试吧:

import numpy as np

# let's keep this one at (5, 3)
p = np.array([[0., -0.05,  0.], [0., -0.05,  0.], [0.,  0.,  0.], [-0.05, -0.1,  0.05], [0.1,  0.2,  0.1]])

c = np.array([[[[ 0.05,  0. ,  0.05], [ 0.05, -0.1,  0.05]], [[ 0.05, -0.1,  0.05], [-0.05, -0.1,  0.05]],
               [[-0.05, -0.1,  0.05], [-0.05,  0. ,  0.05]], [[ 0.05,  0. , -0.05], [ 0.05, -0.1, -0.05]],
               [[ 0.05, -0.1, -0.05], [-0.05, -0.1, -0.05]], [[-0.05, -0.1, -0.05], [-0.05,  0. , -0.05]],
               [[ 0.05,  0. ,  0.05], [ 0.05,  0. , -0.05]], [[ 0.05, -0.1,  0.05], [ 0.05, -0.1, -0.05]],
               [[-0.05, -0.1,  0.05], [-0.05, -0.1, -0.05]], [[-0.05,  0. ,  0.05], [-0.05,  0. , -0.05]]],
              [[[ 0.05,  0.1,  0.05], [ 0.05,  0. ,  0.05]], [[ 0.05,  0. ,  0.05], [-0.05,  0. ,  0.05]],
               [[-0.05,  0. ,  0.05], [-0.05,  0.1,  0.05]], [[ 0.05,  0.1, -0.05], [ 0.05,  0. , -0.05]],
               [[ 0.05,  0. , -0.05], [-0.05,  0. , -0.05]], [[-0.05,  0. , -0.05], [-0.05,  0.1, -0.05]],
               [[ 0.05,  0.1,  0.05], [ 0.05,  0.1, -0.05]], [[ 0.05,  0. ,  0.05], [ 0.05,  0. , -0.05]],
               [[-0.05,  0. ,  0.05], [-0.05,  0. , -0.05]], [[-0.05,  0.1,  0.05], [-0.05,  0.1, -0.05]]]])

# now, it's really as simple as that
result = c[:, :, None, 0] - p  # shape: (2, 10, 5, 3)