在 Python 中计算伪逆 (pinv) 的最快方法

Fastest way for computing pseudoinverse (pinv) in Python

我有一个循环,我在其中计算相当大的非稀疏矩阵(例如 20000x800)的几个伪逆。

由于我的代码大部分时间花在 pinv 上,我试图找到一种加速计算的方法。我已经在多个进程中使用多处理 (joblib/loky) 到 运行,但这当然也会增加开销。使用 jit 并没有多大帮助。

有没有更快的方法/更好的实现来使用任何函数计算伪逆?精度不是关键。

我目前的基准

import time
import numba
import numpy as np
from numpy.linalg import pinv as np_pinv
from scipy.linalg import pinv as scipy_pinv
from scipy.linalg import pinv2 as scipy_pinv2

@numba.njit
def np_jit_pinv(A):
  return np_pinv(A)

matrix = np.random.rand(20000, 800)
for pinv in [np_pinv, scipy_pinv, scipy_pinv2, np_jit_pinv]:
    start = time.time()
    pinv(matrix)
    print(f'{pinv.__module__ +"."+pinv.__name__} took {time.time()-start:.3f}')
numpy.linalg.pinv took 2.774
scipy.linalg.basic.pinv took 1.906
scipy.linalg.basic.pinv2 took 1.682
__main__.np_jit_pinv took 2.446

编辑: JAX 似乎快了 30%!感人的!感谢您让我知道@yuri-brigance。对于 Windows 它在 WSL 下运行良好。

numpy.linalg.pinv took 2.774
scipy.linalg.basic.pinv took 1.906
scipy.linalg.basic.pinv2 took 1.682
__main__.np_jit_pinv took 2.446
jax._src.numpy.linalg.pinv took 0.995

尝试使用 JAX:

import jax.numpy as jnp

jnp.linalg.pinv(A)

似乎比常规 numpy.linalg.pinv 稍微快一点。在我的机器上你的基准看起来像这样:

jax._src.numpy.linalg.pinv took 3.127
numpy.linalg.pinv took 4.284