How to create a customized Linear Operator to solve Ax = b?

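I want to create a linear operator in Python to solve Ax = b, where A is a large dense float64 matrix. Because the matrix A causes performance and memory problems, I am considering a custom operator like this: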

from numpy import ones
from numpy.linalg import inv
import scipy.sparse.linalg
from sklearn.datasets import make_spd_matrix

n = 100


def solver(A, b):
    return inv(A).dot(b)


M = make_spd_matrix(n, random_state=11)
print(M.shape)
solverFunc = scipy.sparse.linalg.LinearOperator((n, n), matvec=solver)

solverFunc.matvec(M, ones((n, 1)))

However, I get the following error:

Traceback (most recent call last):
  File "C:\Users\anoir\Desktop\CG_accelerator\inversion\main.py", line 15, in <module>
    solverFunc = LinearOperator((n, n), matvec=solver)
  File "C:\ProgramData\Anaconda3\envs\inversion\lib\site-packages\scipy\sparse\linalg\interface.py", line 521, in __init__
    self._init_dtype()
  File "C:\ProgramData\Anaconda3\envs\inversion\lib\site-packages\scipy\sparse\linalg\interface.py", line 178, in _init_dtype
    self.dtype = np.asarray(self.matvec(v)).dtype
  File "C:\ProgramData\Anaconda3\envs\inversion\lib\site-packages\scipy\sparse\linalg\interface.py", line 232, in matvec
    y = self._matvec(x)
  File "C:\ProgramData\Anaconda3\envs\inversion\lib\site-packages\scipy\sparse\linalg\interface.py", line 530, in _matvec
    return self.__matvec_impl(x)
TypeError: solver() missing 1 required positional argument: 'b'

What seems to be the problem here? I followed the documentation, but it has nothing about customized LinearOperator objects.

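The matvec function of a LinearOperator takes only one argument: the vector to operate on. Your solver(A, b) takes two, which is why the call fails with "missing 1 required positional argument: 'b'". You can work around this with a closure that captures A, as shown below: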

from numpy.linalg import inv
import numpy as np
import scipy.sparse.linalg
import timeit

n = 100

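def solver_closure(A):
    # Outer enclosing function: captures A so the inner function needs only b
    def solver(b):
        # matvec callback: receives a single vector b and returns inv(A) @ b
        return inv(A).dot(b)
    return solver  # returns the nested function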

M = np.random.rand(n, n)
b = np.arange(n, dtype=float)  # right-hand side as a float array
print(M.shape)

solverFunc = scipy.sparse.linalg.LinearOperator((n, n), matvec=solver_closure(M))

def test100():
    x = solverFunc.matvec(b)
    print(np.matmul(M, x))  # should reproduce b up to rounding error

print(timeit.timeit("test100()", setup="from __main__ import test100", number=10))
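
Note that the closure above recomputes inv(A) on every call. Below is a minimal sketch of two possible variations, not a definitive implementation. The first factors A once with scipy.linalg.lu_factor and reuses the factorization inside matvec. The second wraps only the product A @ x and hands the operator to scipy.sparse.linalg.cg, which assumes A is symmetric positive definite (as the make_spd_matrix matrix in the question is). The helper names make_inverse_operator and make_forward_operator, and the way the test matrix is built, are my own assumptions for illustration.

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve
from scipy.sparse.linalg import LinearOperator, cg


def make_inverse_operator(A):
    """Hypothetical helper: LinearOperator whose matvec(b) solves A x = b."""
    lu_piv = lu_factor(A)  # factor once, reuse on every matvec call

    def solve(b):
        # matvec receives exactly one argument: the right-hand side
        return lu_solve(lu_piv, b)

    # Passing dtype explicitly avoids the trial matvec call used to infer it
    return LinearOperator(A.shape, matvec=solve, dtype=A.dtype)


def make_forward_operator(A):
    """Hypothetical helper: LinearOperator whose matvec(x) computes A @ x."""
    return LinearOperator(A.shape, matvec=lambda x: A @ x, dtype=A.dtype)


n = 100
rng = np.random.default_rng(11)
G = rng.standard_normal((n, n))
A = G @ G.T + n * np.eye(n)  # symmetric positive definite test matrix
b = np.ones(n)

# Direct solve through the factor-once operator
x_direct = make_inverse_operator(A).matvec(b)

# Iterative solve: cg only needs the action x -> A @ x
x_cg, info = cg(make_forward_operator(A), b)

print(np.allclose(A @ x_direct, b))
print(info)  # 0 means cg reported convergence
```

With an iterative solver such as cg, the operator only has to supply the product A @ x, so A never needs to be inverted or even stored explicitly if its action can be computed another way. Whether that helps with the memory problem depends on how A is generated.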