通过 python 中的值找到稀疏二维矩阵的 y 索引

finding y index of a sparse 2D matrix by its value in python

我有一个二维稀疏矩阵 "unknown_tfidf",大小为 (1000,10000),类型为:

<class 'scipy.sparse.csr.csr_matrix'>

我需要获取此矩阵的 y 索引,其中值为 '1',我正在尝试以下方法(不确定它是否是最佳方法甚至是正确方法!)但我遇到了一个错误:

y=[row.index(1.0) for index, row in enumerate(unknown_tfidf) if int(1.0) in row]

错误是:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all().

我的问题是如何才能只得到矩阵值为 1 的矩阵的所有 y 索引?

压缩稀疏行 (CSR) 矩阵等于 1 的 的索引存储在其 .indices 属性中:

import numpy as np
import scipy.sparse as sparse
np.random.seed(2016)

arr = np.round(10*sparse.rand(10, 10, density=0.8, format='csr'))
# arr.A
# array([[  5.,   0.,   7.,   7.,   8.,   7.,   0.,   2.,   4.,   2.],
#        [  4.,   0.,   9.,   2.,   4.,   8.,   4.,   2.,   5.,   9.],
#        [  7.,   4.,   4.,   2.,   4.,   0.,   0.,   0.,   6.,   0.],
#        [  8.,   0.,   0.,   7.,   0.,   6.,   5.,   8.,   0.,   3.],
#        [  3.,   5.,   1.,   0.,   0.,   7.,   3.,   8.,   3.,   0.],
#        [  8.,   6.,   7.,   0.,   8.,   2.,   7.,   0.,   1.,   1.],
#        [  4.,   6.,   3.,   1.,   8.,   7.,   8.,   6.,   0.,   2.],
#        [  7.,   7.,   0.,  10.,   6.,   2.,   4.,   2.,   1.,  10.],
#        [ 10.,   0.,   4.,   8.,   1.,   1.,   3.,   1.,   9.,   1.],
#        [  0.,   4.,   0.,   0.,   7.,   2.,  10.,   1.,   9.,   0.]])

condition = (arr == 1)
print(condition.indices)

产量

[2 8 9 3 8 4 5 7 9 7]

The fastest way求出arr等于1的行索引和列索引,就是把arr转成COO矩阵,然后读出它的rowcol 属性:

coo = condition.tocoo()
print(coo.row)
print(coo.col)

产量

[4 5 5 6 7 8 8 8 8 9]
[2 8 9 3 8 4 5 7 9 7]

您的列表理解适用于嵌套列表

In [100]: xl=[[0,1,3],[0,0,1],[1,1,0]]
In [101]: [row.index(1) for index, row in enumerate(xl) if 1 in row]
Out[101]: [1, 2, 0]

(注意 index returns 只是第三行的第一个匹配)。

但不适用于 numpy.array:

In [102]: xa=np.array(xl)
In [103]: [row.index(1) for index, row in enumerate(xa) if 1 in row]
...
AttributeError: 'numpy.ndarray' object has no attribute 'index'

而不是稀疏矩阵:

In [104]: xs=sparse.csr_matrix(xl)
In [105]: xs
Out[105]: 
<3x3 sparse matrix of type '<class 'numpy.int32'>'
    with 5 stored elements in Compressed Sparse Row format>
In [106]: [row.index(1) for index, row in enumerate(xs) if 1 in row]
...
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all().

如果我删除 if 测试,我会得到一个不同的错误,一个密集数组错误的变体。

In [108]: [row.index(1) for index, row in enumerate(xs)]
...
AttributeError: index not found

看看枚举给我们带来了什么;

In [109]: [(index,row) for index, row in enumerate(xs)]
Out[109]: 
[(0, <1x3 sparse matrix of type '<class 'numpy.int32'>'
    with 2 stored elements in Compressed Sparse Row format>),
 (1, <1x3 sparse matrix of type '<class 'numpy.int32'>'
    with 1 stored elements in Compressed Sparse Row format>),
 (2, <1x3 sparse matrix of type '<class 'numpy.int32'>'
    with 2 stored elements in Compressed Sparse Row format>)]

row 是另一个稀疏矩阵,与 xs[0] 等相同。因此 1 in rowrow.index(1) 表达式必须与数组或矩阵一起使用,否则你会得到一个错误。

我们已经看到 index 方法也没有。这是一个列表方法——你必须对数组或稀疏矩阵使用其他方法。您的理解有 if 子句,因为如果找不到该项目,列表 index 会引发错误。从这个意义上说,if inindex 并存。

in 适用于数组,但给出稀疏矩阵的值错误:

In [114]: 1 in xa[0]
Out[114]: True
In [115]: 1 in xs[0]
....
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all().

更常见的是,此 ValueError 由以下等价物产生:

In [117]: if np.array([True, False, True]):'yes'
...
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

也就是给一个if一个布尔数组。在您的情况下,此故障发生在 sparse 代码中。实际上 in 还没有为稀疏实现。

因此,如果您坚持使用这种列表理解方法,则必须将稀疏矩阵转换为列表列表:

In [120]: [row.index(1) for index, row in enumerate(xs.toarray().tolist()) if 1 in row]
Out[120]: [1, 2, 0]

这是 unutbu's 答案的变体:

使用 matrix/array 相等性测试找到所有匹配的元素:

In [121]: xs==1
Out[121]: 
<3x3 sparse matrix of type '<class 'numpy.bool_'>'
    with 4 stored elements in Compressed Sparse Row format>
In [122]: (xs==1).A
Out[122]: 
array([[False,  True, False],
       [False, False,  True],
       [ True,  True, False]], dtype=bool)

然后使用内置方法获取那些 True 元素的索引:

In [123]: (xs==1).nonzero()
Out[123]: (array([0, 1, 2, 2], dtype=int32), array([1, 2, 0, 1], dtype=int32))

该元组的第二个元素是您想要的列表(第 3 行有 2 个值)。

或者收集行的值(记住,在迭代中每一行都是一个矩阵)

In [125]: [i.nonzero() for i in (xs==1)]
Out[125]: 
[(array([0], dtype=int32), array([1], dtype=int32)),
 (array([0], dtype=int32), array([2], dtype=int32)),
 (array([0, 0], dtype=int32), array([0, 1], dtype=int32))]

将该列表缩减为简单的索引列表需要更多操作

In [131]: [i.nonzero()[1].tolist() for i in (xs==1)]
Out[131]: [[1], [2], [0, 1]]