比 "for loop" 更快的方法 - 通过 2d 数组索引 3d numpy 数组
A faster way than "for loop" - Indexing 3d numpy array by 2d array
我有一个 3d (3, 2, 3)
数组,第一维 (3) 是灵活的,它可以是任何大小。
arr = np.array(
[[[56, 24, 32],
[56, 24, 32]],
[[51, 27, 72],
[51, 27, 72]],
[[36, 14, 49],
[36, 14, 49]]])
索引数组是(2,3)
:
idxs = np.array(
[[1, 0, 2],
[2, 1, 0]])
我想用 idxs 索引 arr。预期结果是:
[[[24 56 32]
[32 24 56]]
[[27 51 72]
[72 27 51]]
[[14 36 49]
[49 14 36]]])
如果我像下面这样使用 for 循环会花费很多时间:
for i, arr2d in enumerate(arr):
for j, (arr1d, idx) in enumerate(zip(arr2d, idxs)):
arr[i, j] = arr1d[idx]
所以我的问题是:我怎样才能加快这个过程?
使用np.take_along_axis
np.take_along_axis(arr, idxs[None, ...], 2)
Out[]:
array([[[24, 56, 32],
[32, 24, 56]],
[[27, 51, 72],
[72, 27, 51]],
[[14, 36, 49],
[49, 14, 36]]])
我有一个 3d (3, 2, 3)
数组,第一维 (3) 是灵活的,它可以是任何大小。
arr = np.array(
[[[56, 24, 32],
[56, 24, 32]],
[[51, 27, 72],
[51, 27, 72]],
[[36, 14, 49],
[36, 14, 49]]])
索引数组是(2,3)
:
idxs = np.array(
[[1, 0, 2],
[2, 1, 0]])
我想用 idxs 索引 arr。预期结果是:
[[[24 56 32]
[32 24 56]]
[[27 51 72]
[72 27 51]]
[[14 36 49]
[49 14 36]]])
如果我像下面这样使用 for 循环会花费很多时间:
for i, arr2d in enumerate(arr):
for j, (arr1d, idx) in enumerate(zip(arr2d, idxs)):
arr[i, j] = arr1d[idx]
所以我的问题是:我怎样才能加快这个过程?
使用np.take_along_axis
np.take_along_axis(arr, idxs[None, ...], 2)
Out[]:
array([[[24, 56, 32],
[32, 24, 56]],
[[27, 51, 72],
[72, 27, 51]],
[[14, 36, 49],
[49, 14, 36]]])