使用蒙版切片二维数组

Slice 2D array using mask

假设数组

0 = {ndarray: (4,)} [5 0 3 3]
1 = {ndarray: (4,)} [7 9 3 5]
2 = {ndarray: (4,)} [2 4 7 6]
3 = {ndarray: (4,)} [8 8 1 6]

我想要 epoch_label 等于零的切片索引

[1 1 0 0]

从上面看,索引将是第二个和第三个索引

使用masked_where,这将产生

[1 1 -- --]

而且,预期的输出应该是

[2 4 7 6]
[8 8 1 6]

但是,使用下面的代码

epoch_com = [np.random.randint(10, size=4) for _ in range(Nepochs)]
epoch_com_arr=np.array(epoch_com)
epoch_label=np.random.randint(2, size=Nepochs)
mm=ma.masked_where(epoch_label == 0, epoch_label)
expected_output=np.where(epoch_com_arr[mm,:])

以上片段代码产生

0 = {ndarray: (14,)} [0 0 0 0 1 1 1 1 2 2 2 3 3 3]
1 = {ndarray: (14,)} [0 1 2 3 0 1 2 3 0 2 3 0 2 3]

这不是我想要的

expected_output=epoch_com_arr[mm,:]

产生了

0 = {ndarray: (4,)} [7 9 3 5]
1 = {ndarray: (4,)} [7 9 3 5]
2 = {ndarray: (4,)} [5 0 3 3]
3 = {ndarray: (4,)} [5 0 3 3]

请问如何解决

 In [242]: Nepochs = 4
 ...: epoch_com = [np.random.randint(10, size=4) for _ in range(Nepochs)]
 ...: epoch_com_arr=np.array(epoch_com)
 ...: epoch_label=np.random.randint(2, size=Nepochs)
 ...: mm=np.ma.masked_where(epoch_label == 0, epoch_label)
 ...: expected_output=np.where(epoch_com_arr[mm,:])

查看变量:

In [246]: epoch_com_arr       # a (4,4) array
Out[246]: 
array([[7, 1, 3, 3],
       [5, 6, 7, 8],
       [5, 6, 3, 8],
       [3, 5, 1, 1]])

我不知道你为什么要使用“0 = {ndarray: (4,)} [5 0 3 3]”的显示方式。不正常numpy.

我认为制作 masked_array 没有任何好处:

In [247]: epoch_label
Out[247]: array([0, 0, 1, 0])
In [248]: mm
Out[248]: 
masked_array(data=[--, --, 1, --],
             mask=[ True,  True, False,  True],
       fill_value=999999)

而是将 0/1 转换为布尔值。通常当我们谈论 'masking' 时,我们的意思是使用布尔数组作为索引,而不是使用 np.ma.

In [249]: epoch_label.astype(bool)
Out[249]: array([False, False,  True, False])

该布尔值可用于 select 行 arr,或者 'deselect' 它们:

In [250]: epoch_com_arr[epoch_label.astype(bool),:]
Out[250]: array([[5, 6, 3, 8]])
In [251]: epoch_com_arr[~epoch_label.astype(bool),:]
Out[251]: 
array([[7, 1, 3, 3],
       [5, 6, 7, 8],
       [3, 5, 1, 1]])

我认为 np.where 在这里没有用。这给出了 epoch_com_arr[mm,:] 中非零项的索引,并且使用 np.ma` 数组进行索引是有问题的。

np.where 可用于将 epoch_label 转换为索引:

In [252]: idx = np.nonzero(epoch_label)   # aka np.where
In [253]: idx
Out[253]: (array([2]),)
In [254]: epoch_com_arr[idx,:]
Out[254]: array([[[5, 6, 3, 8]]])