如何使用 Pytorch LinAlg 求解器?
How to use Pytorch LinAlg Solver?
我正在尝试在 https://pytorch.org/docs/stable/generated/torch.linalg.solve.html
上复制第一个示例
import torch
import time
Acuda = torch.randn(2,3,3,device='cuda')
bcuda = torch.randn(2,3,4,device='cuda')
t1 = time.time()
torch.linalg.torch.solve(Acuda,bcuda)
print('torch took: ',time.time()-t1)
结果我得到
Traceback (most recent call last):
File "linalg_solver_test.py", line 10, in <module>
torch.linalg.torch.solve(Acuda,bcuda)
RuntimeError: A must be batches of square matrices, but they are 4 by 3 matrices
我的 Pytorch 版本是 1.7.1。
与文档页面上的示例相反,我使用的是 torch.linalg.torch.solve
因为 torch.linalg.solve
不存在。
您应该为 LinAlg 使用最新的 PyTorch 1.9,因为它明确提到“支持科学计算的重大改进,包括 torch.linalg”(https://github.com/pytorch/pytorch/releases/tag/v1.9.0)
PyTorch 1.7.1 比较旧。貌似这个版本的LinAlg求解器不支持非方阵
这个老版本解AX = B的函数中参数A,B的顺序其实就是B,A。
https://pytorch.org/docs/stable/generated/torch.solve.html
我正在尝试在 https://pytorch.org/docs/stable/generated/torch.linalg.solve.html
上复制第一个示例import torch
import time
Acuda = torch.randn(2,3,3,device='cuda')
bcuda = torch.randn(2,3,4,device='cuda')
t1 = time.time()
torch.linalg.torch.solve(Acuda,bcuda)
print('torch took: ',time.time()-t1)
结果我得到
Traceback (most recent call last):
File "linalg_solver_test.py", line 10, in <module>
torch.linalg.torch.solve(Acuda,bcuda)
RuntimeError: A must be batches of square matrices, but they are 4 by 3 matrices
我的 Pytorch 版本是 1.7.1。
与文档页面上的示例相反,我使用的是 torch.linalg.torch.solve
因为 torch.linalg.solve
不存在。
您应该为 LinAlg 使用最新的 PyTorch 1.9,因为它明确提到“支持科学计算的重大改进,包括 torch.linalg”(https://github.com/pytorch/pytorch/releases/tag/v1.9.0)
PyTorch 1.7.1 比较旧。貌似这个版本的LinAlg求解器不支持非方阵
这个老版本解AX = B的函数中参数A,B的顺序其实就是B,A。 https://pytorch.org/docs/stable/generated/torch.solve.html