从具有索引列表的多维数组中选择

selecting from a multi-dimesional array with a list of indices

假设我有一个大小为 batch x max_len x output_size 的数组,其中 batchmax_lenoutput_size都对应正自然数。我有一个索引列表,对应于维度 1 中的各个项目(即 max_len)。给定这些索引,我如何从数组中 select?

举个具体的例子,假设我有以下内容:

>>> l = np.random.randn(4,5,6)
>>> l.shape
(4, 5, 6)
>>> idx = [0,0,2,3]

当我 select l 给出 idx 我得到:

>>> l[:,idx,:].shape
(4, 4, 6)
>>>

我也尝试了 np.take 但得到了相同的结果:

>>> np.take(l,idx,axis=1).shape
(4, 4, 6)
>>> 

但是,我要查看的输出是 (4,1,6),因为我试图让一个项目查看 batch 中的每个元素(即第一维)。如何生成具有正确形状的输出?

在扩展 idx 后使用 np.take_along_axis 以获得与 l -

相同的 ndim
np.take_along_axis(l,np.asarray(idx)[:,None,None],axis=1)

使用显式整数数组索引 -

l[np.arange(len(idx)),idx][:,None] # skip [:,None] for (4,6) shaped o/p