定义 numpy 索引数组

Defining numpy indexing arrays

我对 numpy 索引有点困惑。假设我有一个三维数组,例如:

test_arr = np.arange(3*2*3).reshape(3,2,3)
test_arr
array([[[ 0,  1,  2],
        [ 3,  4,  5]],

       [[ 6,  7,  8],
        [ 9, 10, 11]],

       [[12, 13, 14],
        [15, 16, 17]]])

我想通过维度 1 的布尔数组对其进行索引:

dim1_idx = np.array([True, False])
test_arr[:, dim1_idx, :]

这给了我

array([[[ 0,  1,  2]],

       [[ 6,  7,  8]],

       [[12, 13, 14]]])

到目前为止一切都很好。

我的问题是,有没有一种方法可以让我提前定义这个布尔索引数组——比如(这行不通):

all_dim_idx = dim1_idx[np.newaxis, :, np.newaxis]
test_arr[all_dim_idx]

我意识到这不是的原因是因为它无法以某种方式广播以使 all_dim_idx 数组适合 test_arr。我可以使用 np.tile 或 np.reshape 使索引数组适合更大的数组,但是(并且不能再推广到其他数组形状)我只是觉得可能有更好的方法.谁能赐教一下?

提前致谢!

In [600]: test_arr = np.arange(3*2*3).reshape(3,2,3)                            
In [601]: dim1_idx = np.array([True, False])                                    

定义一个索引元组:

In [602]: idx = (slice(None), dim1_idx, slice(None))                            
In [603]: test_arr[idx]                                                         
Out[603]: 
array([[[ 0,  1,  2]],

       [[ 6,  7,  8]],

       [[12, 13, 14]]])