Pytorch 的 nn.TransformerEncoder "src_key_padding_mask" 未按预期运行

Pytorch's nn.TransformerEncoder "src_key_padding_mask" not functioning as expected

我正在使用 Pytorch 的 nn.TransformerEncoder 模块。我得到了具有(正常)形状(batch-size, seq-len, emb-dim)的输入样本。一批中的所有样本都被零填充到该批次中最大样本的大小。因此,我希望忽略所有零值的注意力。

文档说,向 nn.TransformerEncoder 模块的 forward 函数添加参数 src_key_padding_mask。这个掩码应该是一个形状为 (batch-size, seq-len) 的张量,并且每个索引都有 True 用于填充零或 False 用于其他任何东西。

我通过以下方式做到了这一点:

. . .

def forward(self, x):
    # x.size -> i.e.: (200, 28, 200)

    mask = (x == 0).cuda().reshape(x.shape[0], x.shape[1])
    # mask.size -> i.e.: (200, 20)

    x = self.embed(x.type(torch.LongTensor).to(device=device))
    x = self.pe(x)

    x = self.transformer_encoder(x, src_key_padding_mask=mask)

    . . .

当我不设置 src_key_padding_mask 时一切正常。但是我得到的错误如下:

File "/home/me/.conda/envs/py37/lib/python3.7/site-packages/torch/nn/functional.py", line 4282, in multi_head_attention_forward
    assert key_padding_mask.size(0) == bsz
AssertionError

似乎是在比较掩码的第一个维度,即批量大小,bsz 可能代表批量大小。但是为什么会失败呢?非常感谢帮助!

我遇到了同样的问题,这不是错误:pytorch's Transformer implementation 要求输入 x(seq-len x batch-size x emb-dim) 而你的似乎是 (batch-size x seq-len x emb-dim)