稀疏矩阵:如果总和低于 X (Scipy),则删除行

Sparse matrix: removal of rows if their sum is lower than X (Scipy)

比方说,我有以下稀疏矩阵:

from scipy.sparse import coo_matrix
m = coo_matrix(([1,1,1,3,2], ([1,2,2,3,4],[1,1,2,3,3])))
print(m.toarray())

>>> array([[0, 0, 0, 0],
>>>       [0, 1, 0, 0],
>>>       [0, 1, 1, 0],
>>>       [0, 0, 0, 3],
>>>       [0, 0, 0, 2]])

我只想保留那些总和大于 1 的行。我认为下面的方法可行。

csr = m.tocsr()
csr[(csr.sum(1) > 1)]

但事实并非如此。相反,我不得不对 numpy 数组进行转换(使用 squeeze):

csr = m.tocsr()
csr = csr[np.asarray(csr.sum(1) > 1).squeeze()]
csr.toarray()

所以,我得到了我想要的:

array([[0, 1, 1, 0],
       [0, 0, 0, 3],
       [0, 0, 0, 2]], dtype=int64)

有没有更直接的方法?

我知道有类似的答案 在检查了一些其他答案后,如 this one,但在他们的情况下(使用 M.getnnz(1)>0),函数 returns 直接是一个数组。

查看详情:

In [803]: m = sparse.csr_matrix(([1,1,1,3,2], ([1,2,2,3,4],[1,1,2,3,3])))                              
In [804]: m                                                                                            
Out[804]: 
<5x4 sparse matrix of type '<class 'numpy.longlong'>'
    with 5 stored elements in Compressed Sparse Row format>
In [805]: m.A                                                                                          
Out[805]: 
array([[0, 0, 0, 0],
       [0, 1, 0, 0],
       [0, 1, 1, 0],
       [0, 0, 0, 3],
       [0, 0, 0, 2]], dtype=int64)
In [806]: m.sum(axis=1)                                                                                
Out[806]: 
matrix([[0],
        [1],
        [2],
        [3],
        [2]])

sum on ndarray 减小尺寸(除非设置了 keepdims)。但是 sparse 复制 np.matrix,并保留维度。所以结果是一个 (5,1) 矩阵。

np.matrix 有一个 array/ravel 步骤的缩写:

In [807]: m.sum(axis=1).A1                                                                             
Out[807]: array([0, 1, 2, 3, 2])

和索引:

In [811]: m[m.sum(axis=1).A1>1,:]                                                                      
Out[811]: 
<3x4 sparse matrix of type '<class 'numpy.longlong'>'
    with 4 stored elements in Compressed Sparse Row format>
In [812]: _.A                                                                                          
Out[812]: 
array([[0, 1, 1, 0],
       [0, 0, 0, 3],
       [0, 0, 0, 2]], dtype=int64)

我在其他地方提到过 csr 矩阵索引(通常)使用 'extractor matrix' 和矩阵乘法。考虑到数据的存储方式,这是稳健且合理的,但它不如密集数组索引那么快或强大。

有时我们通过作用于矩阵的基础属性 dataindicesindptr 来提高速度。但那需要对那个表示有更多的理解,这里就不赘述了。