乘以许多矩阵和许多向量pytorch

multiply many matrices and many vectors pytorch

我正在尝试将以下内容相乘:

一批矩阵N x M x D
一批向量 N x D x 1
得到结果:N x M x 1

好像我在 M x D D x 1.

上做 N 点积

我似乎无法在 PyTorch 中找到正确的函数。

torch.bmm 据我所知仅适用于一批向量和单个矩阵。如果我必须使用 torch.einsum 那么就这样吧,但 id 而不是!

使用 einsum 非常简单直观:

torch.einsum('ijk, ikl->ijl', mats, vecs)

但是你的操作只是:

mats @ vecs