Replace tail of ndarray based on condition

I have a multidimensional array. Once a critical value appears along the last dimension, I want to mutate the tail of that dimension.

import numpy as np
np.random.seed(100)
arr = np.random.uniform(size=100).reshape([2,5,2,5])
# array([[[[ 0.54340494,  0.27836939,  0.42451759,  0.84477613,  0.00471886],
#          [ 0.12156912,  0.67074908,  0.82585276,  0.13670659,  0.57509333]],
#         [[ 0.89132195,  0.20920212,  0.18532822,  0.10837689,  0.21969749],
#          [ 0.97862378,  0.81168315,  0.17194101,  0.81622475,  0.27407375]],
#         [[ 0.43170418,  0.94002982,  0.81764938,  0.33611195,  0.17541045],
#          [ 0.37283205,  0.00568851,  0.25242635,  0.79566251,  0.01525497]],
#         [[ 0.59884338,  0.60380454,  0.10514769,  0.38194344,  0.03647606],
#          [ 0.89041156,  0.98092086,  0.05994199,  0.89054594,  0.5769015 ]],
#         [[ 0.74247969,  0.63018394,  0.58184219,  0.02043913,  0.21002658],
#          [ 0.54468488,  0.76911517,  0.25069523,  0.28589569,  0.85239509]]],
#        [[[ 0.97500649,  0.88485329,  0.35950784,  0.59885895,  0.35479561],
#          [ 0.34019022,  0.17808099,  0.23769421,  0.04486228,  0.50543143]],
#         [[ 0.37625245,  0.5928054 ,  0.62994188,  0.14260031,  0.9338413 ],
#          [ 0.94637988,  0.60229666,  0.38776628,  0.363188  ,  0.20434528]],
#         [[ 0.27676506,  0.24653588,  0.173608  ,  0.96660969,  0.9570126 ],
#          [ 0.59797368,  0.73130075,  0.34038522,  0.0920556 ,  0.46349802]],
#         [[ 0.50869889,  0.08846017,  0.52803522,  0.99215804,  0.39503593],
#          [ 0.33559644,  0.80545054,  0.75434899,  0.31306644,  0.63403668]],
#         [[ 0.54040458,  0.29679375,  0.1107879 ,  0.3126403 ,  0.45697913],
#          [ 0.65894007,  0.25425752,  0.64110126,  0.20012361,  0.65762481]]]])

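Suppose the critical value is 0.80. Once we see a value at or above 0.80, every value after it needs to change. Let's focus on the first two "rows" to start. For those, np.argmax gives [3,2]: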

where_bigger = np.argmax(arr >= 0.80, axis = 3)
# array([[[3, 2], ## used as example later !!!!!!!!!
#         [0, 0],
#         [1, 0],
#         [0, 0],
#         [0, 4]],
#        [[0, 0],
#         [4, 0],
#         [3, 0],
#         [3, 1],
#         [0, 0]]])

For example, first look at the element with index 3 in [3,2] (see the !!!! above). Once we find a value above 0.80 (here at index 3), all following values should be replaced with np.nan.

arr[0,0,0,3] ## 0.84477613 comes as first element in [3,2]
# [ 0.54340494,  0.27836939,  0.42451759,  0.84477613,  np.nan]

Similarly, here we look at element 2 in [3,2] and need to set all following elements to np.nan.

arr[0,0,1,2] ## 0.82585276 comes as second element in [3,2]
# [ 0.12156912,  0.67074908,  0.82585276,  np.nan,  np.nan]

Finally, we repeat this for every index found by argmax:

# array([[[[ 0.54340494,  0.27836939,  0.42451759,  0.84477613,  np.nan],
#          [ 0.12156912,  0.67074908,  0.82585276,      np.nan,  np.nan]],
#         [[ 0.89132195,      np.nan,      np.nan,      np.nan,  np.nan],
#          [ 0.97862378,      np.nan,      np.nan,      np.nan,  np.nan]],
#         [[ 0.43170418,  0.94002982,      np.nan,      np.nan,  np.nan],
# ...

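Is it possible to adjust the whole array at once, without looping? It can probably be done with slicing. I would like to write something like arr[where_bigger:] = np.nan, but that is obviously wrong. So far I have not been able to get any further.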

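Your best bet is a boolean mask of some kind. You can build the tail with np.logical_or.accumulate, but that will include the index that holds the threshold value. If you want to keep that first instance, you have to pad the mask so it is shifted by one position.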
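To see why the padding is needed, here is a small 1-D sketch. The values are the first row of the array in the question; the names `row`, `seen`, and `tail` are only for illustration.

```python
import numpy as np

# First row of the question's array; 0.84477613 is the first value >= 0.80.
row = np.array([0.54340494, 0.27836939, 0.42451759, 0.84477613, 0.00471886])

# Running OR: True from the first threshold hit onward (the hit itself included).
seen = np.logical_or.accumulate(row >= 0.80)
# seen -> [False, False, False,  True,  True]

# Shift right by one (pad with False in front) so the hit itself is kept.
tail = np.concatenate(([False], seen[:-1]))
# tail -> [False, False, False, False,  True]

row[tail] = np.nan
```

The same shift is applied along the last axis in the N-dimensional case.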

# Prepend a False column and drop the last one, so the first hit itself is kept.
mask = np.c_[np.zeros(arr.shape[:-1] + (1,), dtype=bool), np.logical_or.accumulate(arr >= 0.80, axis=-1)[..., :-1]]
arr[mask] = np.nan
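
If you would rather start from the argmax indices in the question, something along these lines should give the same mask. Treat it as a sketch: `threshold`, `hit`, `first_hit`, `has_hit`, `positions`, and `tail` are names made up for this example. The extra `has_hit` check is there because np.argmax returns 0 both when the first element meets the threshold and when no element does.

```python
import numpy as np

np.random.seed(100)
arr = np.random.uniform(size=100).reshape([2, 5, 2, 5])
threshold = 0.80

hit = arr >= threshold
first_hit = np.argmax(hit, axis=-1)  # index of the first hit (0 if there is none)
has_hit = hit.any(axis=-1)           # tells "hit at index 0" apart from "no hit"

# Compare every position along the last axis with that row's first-hit index.
positions = np.arange(arr.shape[-1])
tail = (positions > first_hit[..., None]) & has_hit[..., None]

arr[tail] = np.nan
```

Broadcasting `positions` against `first_hit[..., None]` is what stands in for the arr[where_bigger:] slice, which NumPy cannot express directly because each row needs a different start index.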