在 numpy 二维数组中按行查找值大于阈值的索引
Find indices of values greater than a threshold by row in a numpy 2darray
我有一个 2darray 如下。我想按数组中的每一行查找高于阈值(例如 0.7)的值的索引。
items= np.array([[1. , 0.40824829, 0.03210806, 0.29488391, 0. ,
0.5 , 0.32444284, 0.57735027, 0. , 0.5 ],
[0.40824829, 1. , 0.57675476, 0.48154341, 0. ,
0.81649658, 0.79471941, 0.70710678, 0.57735027, 0.40824829],
[0.03210806, 0.57675476, 1. , 0.42606683, 0. ,
0. , 0.92713363, 0.834192 , 0. , 0.73848549],
[0.29488391, 0.48154341, 0.42606683, 1. , 0. ,
0.29488391, 0.52620136, 0.51075392, 0.20851441, 0.44232587],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ],
[0.5 , 0.81649658, 0. , 0.29488391, 0. ,
1. , 0.32444284, 0.28867513, 0.70710678, 0. ],
[0.32444284, 0.79471941, 0.92713363, 0.52620136, 0. ,
0.32444284, 1. , 0.93658581, 0.22941573, 0.81110711],
[0.57735027, 0.70710678, 0.834192 , 0.51075392, 0. ,
0.28867513, 0.93658581, 1. , 0. , 0.8660254 ],
[0. , 0.57735027, 0. , 0.20851441, 0. ,
0.70710678, 0.22941573, 0. , 1. , 0. ],
[0.5 , 0.40824829, 0.73848549, 0.44232587, 0. ,
0. , 0.81110711, 0.8660254 , 0. , 1. ]])
indices_items = np.argwhere(items>= 0.7)
这个(indices_items)returns
array([[0, 0],
[1, 1],
[1, 5],
[1, 6],
[1, 7],
[2, 2],
[2, 6],
[2, 7],
[2, 9],
[3, 3],
[5, 1],
[5, 5],
[5, 8],
[6, 1],
[6, 2],
[6, 6],
[6, 7],
[6, 9],
[7, 1],
[7, 2],
[7, 6],
[7, 7],
[7, 9],
[8, 5],
[8, 8],
[9, 2],
[9, 6],
[9, 7],
[9, 9]], dtype=int64)
如何按行获取索引,如下所示?
row0 -> [0] row1-> [0,1,5,6,7] row2-> [2,6,7,9] row3-> [3] row4-> []
#这应该是空列表,因为没有超过阈值的值...
获取行,col 与 np.where
,然后使用 np.searchsorted
获取 row-array 上的间隔索引并使用它们拆分 col-array -
In [38]: r,c = np.where(items>= 0.7)
In [39]: np.split(c,np.searchsorted(r,range(1,items.shape[0])))
Out[39]:
[array([0], dtype=int64),
array([1, 5, 6, 7], dtype=int64),
array([2, 6, 7, 9], dtype=int64),
array([3], dtype=int64),
array([], dtype=int64),
array([1, 5, 8], dtype=int64),
array([1, 2, 6, 7, 9], dtype=int64),
array([1, 2, 6, 7, 9], dtype=int64),
array([5, 8], dtype=int64),
array([2, 6, 7, 9], dtype=int64)]
这在性能方面可能不是最佳的,但如果你真的不关心它应该没问题。
indices_items = []
for l in items:
indices_items.append(np.argwhere(l >= 0.7).flatten().tolist())
indices_items
Out[5]:
[[0],
[1, 5, 6, 7],
[2, 6, 7, 9],
[3],
[],
[1, 5, 8],
[1, 2, 6, 7, 9],
[1, 2, 6, 7, 9],
[5, 8],
[2, 6, 7, 9]]
我有一个 2darray 如下。我想按数组中的每一行查找高于阈值(例如 0.7)的值的索引。
items= np.array([[1. , 0.40824829, 0.03210806, 0.29488391, 0. ,
0.5 , 0.32444284, 0.57735027, 0. , 0.5 ],
[0.40824829, 1. , 0.57675476, 0.48154341, 0. ,
0.81649658, 0.79471941, 0.70710678, 0.57735027, 0.40824829],
[0.03210806, 0.57675476, 1. , 0.42606683, 0. ,
0. , 0.92713363, 0.834192 , 0. , 0.73848549],
[0.29488391, 0.48154341, 0.42606683, 1. , 0. ,
0.29488391, 0.52620136, 0.51075392, 0.20851441, 0.44232587],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ],
[0.5 , 0.81649658, 0. , 0.29488391, 0. ,
1. , 0.32444284, 0.28867513, 0.70710678, 0. ],
[0.32444284, 0.79471941, 0.92713363, 0.52620136, 0. ,
0.32444284, 1. , 0.93658581, 0.22941573, 0.81110711],
[0.57735027, 0.70710678, 0.834192 , 0.51075392, 0. ,
0.28867513, 0.93658581, 1. , 0. , 0.8660254 ],
[0. , 0.57735027, 0. , 0.20851441, 0. ,
0.70710678, 0.22941573, 0. , 1. , 0. ],
[0.5 , 0.40824829, 0.73848549, 0.44232587, 0. ,
0. , 0.81110711, 0.8660254 , 0. , 1. ]])
indices_items = np.argwhere(items>= 0.7)
这个(indices_items)returns
array([[0, 0],
[1, 1],
[1, 5],
[1, 6],
[1, 7],
[2, 2],
[2, 6],
[2, 7],
[2, 9],
[3, 3],
[5, 1],
[5, 5],
[5, 8],
[6, 1],
[6, 2],
[6, 6],
[6, 7],
[6, 9],
[7, 1],
[7, 2],
[7, 6],
[7, 7],
[7, 9],
[8, 5],
[8, 8],
[9, 2],
[9, 6],
[9, 7],
[9, 9]], dtype=int64)
如何按行获取索引,如下所示? row0 -> [0] row1-> [0,1,5,6,7] row2-> [2,6,7,9] row3-> [3] row4-> [] #这应该是空列表,因为没有超过阈值的值...
获取行,col 与 np.where
,然后使用 np.searchsorted
获取 row-array 上的间隔索引并使用它们拆分 col-array -
In [38]: r,c = np.where(items>= 0.7)
In [39]: np.split(c,np.searchsorted(r,range(1,items.shape[0])))
Out[39]:
[array([0], dtype=int64),
array([1, 5, 6, 7], dtype=int64),
array([2, 6, 7, 9], dtype=int64),
array([3], dtype=int64),
array([], dtype=int64),
array([1, 5, 8], dtype=int64),
array([1, 2, 6, 7, 9], dtype=int64),
array([1, 2, 6, 7, 9], dtype=int64),
array([5, 8], dtype=int64),
array([2, 6, 7, 9], dtype=int64)]
这在性能方面可能不是最佳的,但如果你真的不关心它应该没问题。
indices_items = []
for l in items:
indices_items.append(np.argwhere(l >= 0.7).flatten().tolist())
indices_items
Out[5]:
[[0],
[1, 5, 6, 7],
[2, 6, 7, 9],
[3],
[],
[1, 5, 8],
[1, 2, 6, 7, 9],
[1, 2, 6, 7, 9],
[5, 8],
[2, 6, 7, 9]]