Pytorch 或 Numpy 批量矩阵操作

Pytorch or Numpy Batch Matrix Operation

我正在尝试 torch.bmm 进行以下矩阵运算,

如果matrix是一个M * N张量,batch是一个N * B张量,我该如何实现, 在每个批次中,矩阵@batch_i,给出M,将批次大小放在一起,输出张量看起来像M * B

这里有两个问题,

1.To使用torch.bmm,好像两个矩阵都需要batch,但是我的第一个输入不是

  1. batch size需要是第一个维度,而我的batch size在最后

我想这对 Numpy 用户来说也是同样的问题

好像torch.einsum('ij,jbc->ibc', A, B)就能解决问题