定义 numpy 索引数组
Defining numpy indexing arrays
我对 numpy 索引有点困惑。假设我有一个三维数组,例如:
test_arr = np.arange(3*2*3).reshape(3,2,3)
test_arr
array([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]],
[[12, 13, 14],
[15, 16, 17]]])
我想通过维度 1 的布尔数组对其进行索引:
dim1_idx = np.array([True, False])
test_arr[:, dim1_idx, :]
这给了我
array([[[ 0, 1, 2]],
[[ 6, 7, 8]],
[[12, 13, 14]]])
到目前为止一切都很好。
我的问题是,有没有一种方法可以让我提前定义这个布尔索引数组——比如(这行不通):
all_dim_idx = dim1_idx[np.newaxis, :, np.newaxis]
test_arr[all_dim_idx]
我意识到这不是的原因是因为它无法以某种方式广播以使 all_dim_idx 数组适合 test_arr。我可以使用 np.tile 或 np.reshape 使索引数组适合更大的数组,但是(并且不能再推广到其他数组形状)我只是觉得可能有更好的方法.谁能赐教一下?
提前致谢!
In [600]: test_arr = np.arange(3*2*3).reshape(3,2,3)
In [601]: dim1_idx = np.array([True, False])
定义一个索引元组:
In [602]: idx = (slice(None), dim1_idx, slice(None))
In [603]: test_arr[idx]
Out[603]:
array([[[ 0, 1, 2]],
[[ 6, 7, 8]],
[[12, 13, 14]]])
我对 numpy 索引有点困惑。假设我有一个三维数组,例如:
test_arr = np.arange(3*2*3).reshape(3,2,3)
test_arr
array([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]],
[[12, 13, 14],
[15, 16, 17]]])
我想通过维度 1 的布尔数组对其进行索引:
dim1_idx = np.array([True, False])
test_arr[:, dim1_idx, :]
这给了我
array([[[ 0, 1, 2]],
[[ 6, 7, 8]],
[[12, 13, 14]]])
到目前为止一切都很好。
我的问题是,有没有一种方法可以让我提前定义这个布尔索引数组——比如(这行不通):
all_dim_idx = dim1_idx[np.newaxis, :, np.newaxis]
test_arr[all_dim_idx]
我意识到这不是的原因是因为它无法以某种方式广播以使 all_dim_idx 数组适合 test_arr。我可以使用 np.tile 或 np.reshape 使索引数组适合更大的数组,但是(并且不能再推广到其他数组形状)我只是觉得可能有更好的方法.谁能赐教一下?
提前致谢!
In [600]: test_arr = np.arange(3*2*3).reshape(3,2,3)
In [601]: dim1_idx = np.array([True, False])
定义一个索引元组:
In [602]: idx = (slice(None), dim1_idx, slice(None))
In [603]: test_arr[idx]
Out[603]:
array([[[ 0, 1, 2]],
[[ 6, 7, 8]],
[[12, 13, 14]]])