如何索引由某个轴上的索引给出的多维数组?
How to indexing multi-dimensional arrays given by indices in a certain axis?
假设我有一个 4d 数组 A
,形状为 (D0, D1, D2, D3)
。我有一个形状为 (D0,)
的一维数组 B
,其中包括我在轴 2 处需要的索引。
实现我需要的简单方法:
output_lis = []
for i in range(D0):
output_lis.append(A[i, :, B[i], :])
#output = np.concatenate(output_lis, axis=0) #it is wrong to use concatenate. Thanks to @Mad Physicist. Instead, using stack.
output = np.stack(output_lis, axis=0) #shape: [D0, D1, D3]
所以,我的问题是如何用 numpy API 快速实现它?
使用花哨的索引在两个维度上步调一致。在这种情况下,arange
提供序列 i
,而 B
提供序列 B[i]
:
A[np.arange(D0), :, B, :]
这个数组的形状确实是 (D0, D1, D3)
,不像你的 for
循环结果的形状。
要从您的示例中获得相同的结果,请使用 stack
(添加新轴),而不是 concatenate
(使用现有轴):
output = np.stack(output_lis, axis=0)
假设我有一个 4d 数组 A
,形状为 (D0, D1, D2, D3)
。我有一个形状为 (D0,)
的一维数组 B
,其中包括我在轴 2 处需要的索引。
实现我需要的简单方法:
output_lis = []
for i in range(D0):
output_lis.append(A[i, :, B[i], :])
#output = np.concatenate(output_lis, axis=0) #it is wrong to use concatenate. Thanks to @Mad Physicist. Instead, using stack.
output = np.stack(output_lis, axis=0) #shape: [D0, D1, D3]
所以,我的问题是如何用 numpy API 快速实现它?
使用花哨的索引在两个维度上步调一致。在这种情况下,arange
提供序列 i
,而 B
提供序列 B[i]
:
A[np.arange(D0), :, B, :]
这个数组的形状确实是 (D0, D1, D3)
,不像你的 for
循环结果的形状。
要从您的示例中获得相同的结果,请使用 stack
(添加新轴),而不是 concatenate
(使用现有轴):
output = np.stack(output_lis, axis=0)