如何获取稀疏矩阵数据数组的对角线元素的索引

How to get indices of diagonal elements of a sparse matrix data array

我有一个 csr 格式的稀疏矩阵,例如:

>>> a = sp.random(3, 3, 0.6, format='csr')  # an example
>>> a.toarray()  # just to see how it looks like
array([[0.31975333, 0.88437035, 0.        ],
       [0.        , 0.        , 0.        ],
       [0.14013856, 0.56245834, 0.62107962]])
>>> a.data  # data array
array([0.31975333, 0.88437035, 0.14013856, 0.56245834, 0.62107962])

对于这个特定的例子,我想得到 [0, 4],它们是非零对角线元素 0.319753330.62107962.

的数据数组索引

执行此操作的简单方法如下:

ind = []
seen = set()
for i, val in enumerate(a.data):
    if val in a.diagonal() and val not in seen:
        ind.append(i)
        seen.add(val)

但实际上矩阵非常大,所以我不想使用 for 循环或使用 toarray() 方法转换为 numpy 数组。有没有更有效的方法呢?

编辑:我刚刚意识到上面的代码在非对角线元素等于和前面一些对角线元素的情况下给出了错误的结果:它returns 该非对角线元素的索引。此外,它不 return 重复对角线元素的索引。例如:

a = np.array([[0.31975333, 0.88437035, 0.        ],
              [0.62107962, 0.31975333, 0.        ],
              [0.14013856, 0.56245834, 0.62107962]])
a = sp.csr_matrix(a)

>>> a.data
array([0.31975333, 0.88437035, 0.62107962, 0.31975333, 0.14013856,
       0.56245834, 0.62107962])

我的代码returnsind = [0, 2],不过应该是[0, 3, 6]。 Andras Deak 提供的代码(他的 get_rowwise 函数),return 是正确的结果。

我找到了一个可能更有效的解决方案,尽管它仍然循环。但是,它遍历矩阵的行而不是元素本身。根据矩阵的稀疏模式,这可能会或可能不会更快。对于具有 N 行的稀疏矩阵,这保证会花费 N 次迭代。

我们只是遍历每一行,通过 a.indicesa.indptr 获取填充的列索引,如果给定行的对角元素存在于填充值中,那么我们计算它的索引:

import numpy as np
import scipy.sparse as sp

def orig_loopy(a):
    ind = []
    seen = set()
    for i, val in enumerate(a.data):
        if val in a.diagonal() and val not in seen:
            ind.append(i)
            seen.add(val)
    return ind

def get_rowwise(a):
    datainds = []
    indices = a.indices # column indices of filled values
    indptr = a.indptr   # auxiliary "pointer" to data indices
    for irow in range(a.shape[0]):
        rowinds = indices[indptr[irow]:indptr[irow+1]] # column indices of the row
        if irow in rowinds:
            # then we've got a diagonal in this row
            # so let's find its index
            datainds.append(indptr[irow] + np.flatnonzero(irow == rowinds)[0])
    return datainds

a = sp.random(300, 300, 0.6, format='csr')
orig_loopy(a) == get_rowwise(a) # True

对于具有相同密度的 (300,300) 形随机输入,原始版本运行时间为 3.7 秒,新版本运行时间为 5.5 毫秒。

方法一

这是一种向量化方法,它首先生成所有非零索引,然后获取行和列索引相同的位置。这有点慢,而且内存占用高。

import numpy as np
import scipy.sparse as sp
import numba as nb

def get_diag_ind_vec(csr_array):
  inds=csr_array.nonzero()
  return np.array(np.where(inds[0]==inds[1])[0])

方法二

循环方法通常在性能方面没有问题,只要您使用编译器,例如。 NumbaCython。我为可能出现的最大对角线元素分配了内存。如果此方法占用大量内存,可以轻松修改。

@nb.jit()
def get_diag_ind(csr_array):
    ind=np.empty(csr_array.shape[0],dtype=np.uint64)
    rowPtr=csr_array.indptr
    colInd=csr_array.indices

    ii=0
    for i in range(rowPtr.shape[0]-1):
      for j in range(rowPtr[i],rowPtr[i+1]):
        if (i==colInd[j]):
          ind[ii]=j
          ii+=1

    return ind[:ii]

时间

csr_array = sp.random(1000, 1000, 0.5, format='csr')

get_diag_ind_vec(csr_array)   -> 8.25ms
get_diag_ind(csr_array)       -> 0.65ms (first call excluded)

这是我的解决方案,似乎比 get_rowwise (Andras Deak) 和 get_diag_ind_vec (max9111) 更快(我不考虑使用 Numba 或 Cython)。

想法是将矩阵(或其副本)的非零对角元素设置为原始矩阵中不存在的某个唯一值x(我选择最大值+1),然后简单地使用 np.where(a.data == x) 到 return 所需的索引。

def diag_ind(a):
    a = a.copy()
    i = a.diagonal() != 0  
    x = np.max(a.data) + 1
    a[i, i] = x
    return np.where(a.data == x)

时间:

A = sp.random(1000, 1000, 0.5, format='csr')

>>> %timeit diag_ind(A)
6.32 ms ± 335 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

>>> %timeit get_diag_ind_vec(A)
14.6 ms ± 292 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

>>> %timeit get_rowwise(A)
24.3 ms ± 5.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

编辑: 复制稀疏矩阵(为了保留原始矩阵)的内存效率不高,因此更好的解决方案是存储对角线元素并在以后使用它们用于恢复原始矩阵。

def diag_ind2(a):
    a_diag = a.diagonal()
    i = a_diag != 0  
    x = np.max(a.data) + 1
    a[i, i] = x
    ind = np.where(a.data == x)
    a[i, i] = a_diag[np.nonzero(a_diag)]
    return ind

这个更快:

>>> %timeit diag_ind2(A)
2.83 ms ± 419 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)