torch.matmul gives RuntimeError

I have two tensors with the following shapes:

t1 = torch.Size([400, 32, 400])
t2 = torch.Size([400, 32, 32])

When I run torch.matmul(t1, t2)

I get this RuntimeError:

Expected tensor to have size 400 at dimension 1, but got size 32 for argument #2 'batch2' (while checking arguments for bmm)

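Any help would be appreciated.

You get this error because the operands are in the wrong order. For 3-D inputs, torch.matmul treats the first dimension as the batch and multiplies the last two dimensions as matrices, so the last dimension of the first tensor must equal the second-to-last dimension of the second tensor. In torch.matmul(t1, t2) those are 400 and 32, which do not match.

It should be: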

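import torch

a = torch.randn(400, 32, 400)  # same shape as t1
b = torch.randn(400, 32, 32)   # same shape as t2
# The question called torch.matmul(a, b); with the operands swapped, the inner dimensions match.
out = torch.matmul(b, a)  # shape: (400, 32, 400)
# The @ operator is shorthand for the same batched matrix multiplication.
out = b @ a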
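
To make the shape rule concrete, here is a minimal sketch that checks compatibility before multiplying. The helper can_batch_matmul is a hypothetical name introduced only for illustration (it is not part of PyTorch), and it covers only the 3-D case with equal batch sizes, not general broadcasting. The transposed variant at the end is an assumption about what the asker might have wanted instead; which product is correct depends on the intended math.

```python
import torch

a = torch.randn(400, 32, 400)  # batch of 400 matrices, each 32 x 400
b = torch.randn(400, 32, 32)   # batch of 400 matrices, each 32 x 32


def can_batch_matmul(x, y):
    """Hypothetical helper: True if x @ y is valid for two 3-D tensors
    with the same batch size (broadcasting is not considered)."""
    return x.shape[0] == y.shape[0] and x.shape[-1] == y.shape[-2]


ok_ab = can_batch_matmul(a, b)  # inner dimensions are 400 and 32
ok_ba = can_batch_matmul(b, a)  # inner dimensions are 32 and 32

# (32 x 32) @ (32 x 400) per batch element
out = torch.matmul(b, a)

# Alternative, if the intended product was a^T @ b per batch element:
# (400 x 32) @ (32 x 32)
alt = torch.matmul(a.transpose(1, 2), b)

print(ok_ab, ok_ba, out.shape, alt.shape)
```

Note that the two products are different results with different shapes, so swapping the operands and transposing are not interchangeable fixes.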