从分隔字符串创建 scipy.sparse.csr_matrix

create a scipy.sparse.csr_matrix from a delimited string

我正在处理大量二进制数据,这些数据以看起来像 \t\t1\t\t1\t\t\t(但更长)的字符串形式逐行进入我的程序。可以想象,这些是制表符分隔文件中的行。

显然,我可以执行 '\t\t1\t\t1\t\t\t'.split('\t') 并获得 1'' 的列表,我可以轻松地将其转换为 1 和 0 或 T/F 管他呢。但是,数据非常稀疏(很多 0 而不是很多 1)所以我希望使用某种稀疏表示。

我的问题是:有人知道有什么方法可以直接从这个字符串转到 scipy.sparse.csr_matrix() 之类的东西,而无需 首先创建一个中间密集矩阵吗?

我尝试将拆分字符串(即 1'' 的列表)直接传递给 csr_matrix(),我得到了 TypeError: no supported conversion for types: (dtype('<U1'),)

正如我所说,我可以执行上述操作并得到 1 和 0,然后将 that 转换为 csr_matrix() 但是我失去了所有的速度和稀疏的内存优势,因为我正在创建完全密集的版本。

scipy 无法解释您的输入,因为它不知道您希望将空字符串转换为 0。这很好用:

>>> from scipy.sparse import csr_matrix
>>> x = [0 if not a else int(a) for a in "\t\t\t\t1\t\t\t1\t\t\t".split('\t')] 
>>> csr_matrix(x)
<1x11 sparse matrix of type '<class 'numpy.int64'>'
        with 2 stored elements in Compressed Sparse Row format>

确保你的列表在矩阵化之前都是 numbrt 格式。

根据 OP 评论我记得你可以强制将空字符串转换为 0,所以更好的解决方案

>>> csr_matrix("\t\t\t\t1\t\t\t1\t\t\t".split('\t'),dtype=np.int64)
<1x11 sparse matrix of type '<class 'numpy.int64'>'
        with 2 stored elements in Compressed Sparse Row format>

少生成一个列表。

下面是逐行处理数据的方法:

In [32]: astr = '\t\t1\t\t1\t\t\t'      # sample row
In [33]: row, col = [],[]
In [34]: for i in range(5):
    ...:     c = [j for j,v in enumerate(astr.split('\t')) if v]
    ...:     row.extend([i]*len(c))
    ...:     col.extend(c)
    ...: data = np.ones(len(col),'int32')
    ...: M = sparse.csr_matrix((data, (row, col)))
    ...: 
In [35]: M
Out[35]: 
<5x5 sparse matrix of type '<class 'numpy.int32'>'
    with 10 stored elements in Compressed Sparse Row format>
In [36]: M.A
Out[36]: 
array([[0, 0, 1, 0, 1],
       [0, 0, 1, 0, 1],
       [0, 0, 1, 0, 1],
       [0, 0, 1, 0, 1],
       [0, 0, 1, 0, 1]], dtype=int32)

对于每一行,我只收集“1”的索引。从这些我构建相应的 datarow 列表(或数组)。理论上我可以构造 indptr 来更直接地 csr 创建,但是 coo 样式更容易理解。

中间值为:

In [40]: c
Out[40]: [2, 4]
In [41]: row
Out[41]: [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
In [42]: col
Out[42]: [2, 4, 2, 4, 2, 4, 2, 4, 2, 4]
In [43]: data
Out[43]: array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)

获取 c 值的另一种方法是:

In [46]: np.where(astr.split('\t'))[0]
Out[46]: array([2, 4])

(但列表理解速度更快)。

字符串和列表 find/index 方法找到第一项,但不是全部。