Numpy:多维索引。逐行无循环

Numpy: Multidimensional index. Row by row with no loop

我有一个名为 A 的 Nx2x2x2 数组。我还有一个名为 B 的 Nx2 数组,它告诉我我感兴趣的 A 的最后两个维度的位置。我目前正在通过使用循环(如以下代码中的 C 中)或使用列表理解(如以下代码中的 D 中)获得 Nx2 数组。我想知道矢量化是否会节省时间,如果可以,如何矢量化此任务。

我目前的矢量化方法(下面代码中的 E)是使用 B 来索引 A 的每个子矩阵,因此它 return 不是我想要的。我希望 E return 与 C 或 D 相同。

输入:

A=np.reshape(np.arange(0,32),(4,2,2,2))
print("A")
print(A)
B=np.concatenate((np.array([0,1,0,1])[:,np.newaxis],np.array([1,1,0,0])[:,np.newaxis]),axis=1)
print("B")
print(B)
C=np.empty(shape=(4,2))
for n in range(0, 4):
    C[n,:]=A[n,:,B[n,0],B[n,1]]
print("C")
print(C)
D = np.array([A[n,:,B[n,0],B[n,1]] for n in range(0, 4)])
print("D")
print(D)
E=A[:,:,B[:,0],B[:,1]]
print("E")
print(E)

输出:

A
[[[[ 0  1]
   [ 2  3]]

  [[ 4  5]
   [ 6  7]]]


 [[[ 8  9]
   [10 11]]

  [[12 13]
   [14 15]]]


 [[[16 17]
   [18 19]]

  [[20 21]
   [22 23]]]


 [[[24 25]
   [26 27]]

  [[28 29]
   [30 31]]]]
B
[[0 1]
 [1 1]
 [0 0]
 [1 0]]
C
[[  1.   5.]
 [ 11.  15.]
 [ 16.  20.]
 [ 26.  30.]]
D
[[ 1  5]
 [11 15]
 [16 20]
 [26 30]]
E
[[[ 1  3  0  2]
  [ 5  7  4  6]]

 [[ 9 11  8 10]
  [13 15 12 14]]

 [[17 19 16 18]
  [21 23 20 22]]

 [[25 27 24 26]
  [29 31 28 30]]]

复杂的切片操作可以像这样以向量化的方式完成-

shp = A.shape
out = A.reshape(shp[0],shp[1],-1)[np.arange(shp[0]),:,B[:,0]*shp[3] + B[:,1]]

您正在使用 B 的第一列和第二列索引输入 4D 数组 A 的第三维和第四维。它的意思是,基本上您正在对 4D 数组进行切片,最后两个维度 融合 在一起。因此,您需要使用 B 获得具有该融合格式的线性索引。当然,在做所有这些之前,您需要使用 A.reshape(shp[0],shp[1],-1).

A 重塑为 3D 数组

验证通用 4D 数组案例的结果 -

In [104]: A = np.random.rand(6,3,4,5)
     ...: B = np.concatenate((np.random.randint(0,4,(6,1)),np.random.randint(0,5,(6,1))),1)
     ...: 

In [105]: C=np.empty(shape=(6,3))
     ...: for n in range(0, 6):
     ...:     C[n,:]=A[n,:,B[n,0],B[n,1]]
     ...:     

In [106]: shp = A.shape
     ...: out = A.reshape(shp[0],shp[1],-1)[np.arange(shp[0]),:,B[:,0]*shp[3] + B[:,1]]
     ...: 

In [107]: np.allclose(C,out)
Out[107]: True