在两个数组之间执行矩阵乘法并仅在被屏蔽的地方获得结果

Perform matrix multiplication between two arrays and get result only on masked places

我有两个密集矩阵,A [200000,10],B [10,100000]。我需要将它们相乘得到矩阵 C。我不能直接这样做,因为生成的矩阵不适合内存。此外,我只需要结果矩阵中的几个元素,比如元素总数的 1-2%。我有第三个矩阵 W [200000,100000] 它是稀疏的并且在矩阵 C[= 我感兴趣的那些地方有非零元素19=]。 有没有办法将 W 用作 "mask" 以便生成的矩阵 C 将是稀疏的并且仅包含所需的元素?

首先得到W中非零位置的索引,然后将A中第i行与第j列相乘即可得到结果矩阵的(i,j)元素在 B 中,将结果保存为元组 (i,j,res) 而不是将其保存为矩阵(这是保存稀疏矩阵的正确方法)。

由于矩阵乘法只是 table 的点积,我们可以以向量化的方式执行我们需要的特定点积。

import numpy as np
import scipy as sp

iX, iY = sp.nonzero(W)
values = np.sum(A[iX]*B[:, iY].T, axis=-1) #batched dot product
C = sp.sparse.coo_matrix(values, np.asarray([iX,iY]).T)

这是使用 np.einsum 进行矢量化解决方案的一种方法 -

from scipy import sparse
from scipy.sparse import coo_matrix

# Get row, col for the output array
r,c,_= sparse.find(W)

# Get the sum-reduction using valid rows and corresponding cols from A, B
out = np.einsum('ij,ji->i',A[r],B[:,c])

# Store as sparse matrix
out_sparse = coo_matrix((out, (r, c)), shape=W.shape)

示例 运行 -

1) 输入:

In [168]: A
Out[168]: 
array([[4, 6, 1, 1, 1],
       [0, 8, 1, 3, 7],
       [2, 8, 3, 2, 2],
       [3, 4, 1, 6, 3]])

In [169]: B
Out[169]: 
array([[5, 2, 4],
       [2, 1, 3],
       [7, 7, 2],
       [5, 7, 5],
       [8, 5, 0]])

In [176]: W
Out[176]: 
<4x3 sparse matrix of type '<type 'numpy.bool_'>'
    with 5 stored elements in Compressed Sparse Row format>

In [177]: W.toarray()
Out[177]: 
array([[ True, False, False],
       [False, False, False],
       [ True,  True, False],
       [ True, False,  True]], dtype=bool)

2) 使用密集数组进行直接计算,稍后验证结果:

In [171]: (A.dot(B))*W.toarray()
Out[171]: 
array([[52,  0,  0],
       [ 0,  0,  0],
       [73, 57,  0],
       [84,  0, 56]])

3) 使用建议的代码并获得稀疏矩阵输出:

In [172]: # Using proposed codes
     ...: r,c,_= sparse.find(W)
     ...: out = np.einsum('ij,ji->i',A[r],B[:,c])
     ...: out_sparse = coo_matrix((out, (r, c)), shape=W.shape)
     ...: 

4) 最后通过转换为 dense/array 版本并检查直接版本来验证结果 -

In [173]: out_sparse.toarray()
Out[173]: 
array([[52,  0,  0],
       [ 0,  0,  0],
       [73, 57,  0],
       [84,  0, 56]])