如何使用 Pytorch LinAlg 求解器?

How to use Pytorch LinAlg Solver?

我正在尝试在 https://pytorch.org/docs/stable/generated/torch.linalg.solve.html

上复制第一个示例
import torch
import time


Acuda = torch.randn(2,3,3,device='cuda')
bcuda = torch.randn(2,3,4,device='cuda')


t1 = time.time()
torch.linalg.torch.solve(Acuda,bcuda)

print('torch took: ',time.time()-t1)

结果我得到

Traceback (most recent call last):
 File "linalg_solver_test.py", line 10, in <module>
     torch.linalg.torch.solve(Acuda,bcuda) 
     RuntimeError: A must be batches of square matrices, but they are 4 by 3 matrices

我的 Pytorch 版本是 1.7.1。 与文档页面上的示例相反,我使用的是 torch.linalg.torch.solve 因为 torch.linalg.solve 不存在。

您应该为 LinAlg 使用最新的 PyTorch 1.9,因为它明确提到“支持科学计算的重大改进,包括 torch.linalg”(https://github.com/pytorch/pytorch/releases/tag/v1.9.0)

PyTorch 1.7.1 比较旧。貌似这个版本的LinAlg求解器不支持非方阵

这个老版本解AX = B的函数中参数A,B的顺序其实就是B,A。 https://pytorch.org/docs/stable/generated/torch.solve.html