如果嵌套数组的最大值高于阈值,则获取嵌套数组的 Numpy 条件

Numpy condition for getting nested arrays if their max is above a threshold

我有以下数组:

arr = numpy.array([[.5, .5], [.9, .1], [.8, .2]])

我想获取 arr 索引 ,其中包含一个数组,其最大值大于或等于 .9。因此,对于这种情况,结果将是 [1] 因为索引为 1 [.9, .1] 的数组是唯一一个最大值 >= 9.

我试过了:

>>> condition = np.max(arr) >= .9
>>> arr[condition]
array([ 0.5,  0.5])

但是,如您所见,它给出了错误的答案。

沿轴使用max获取行最大值,然后where获取最大的索引:

np.where(arr.max(axis=1)>=0.9)

我想你想要 np.where 这里。此函数 returns 满足特定条件的任何值的索引:

>>> np.where(arr >= 0.9)[0] # here we look at the whole 2D array
array([1])

(np.where(arr >= 0.9) returns 索引数组的元组,数组的每个轴一个。您的预期输出意味着您只需要行索引(轴 0)。)

如果想先取每一行的最大值,可以使用arr.max(axis=1):

>>> np.where(arr.max(axis=1) >= 0.9)[0] # here we look at the 1D array of row maximums
array([1])

你得到错误答案的原因是因为 np.max(arr) 给了你展平数组的最大值。你想要 np.max(arr, axis=1) 或者更好的是 arr.max(axis=1).

(arr.max(axis=1)>=.9).nonzero()
In [18]: arr = numpy.array([[.5, .5], [.9, .1], [.8, .2]])

In [19]: numpy.argwhere(numpy.max(arr, 1) >= 0.9)
Out[19]: array([[1]])