Pytorch 的 nn.TransformerEncoder "src_key_padding_mask" 未按预期运行
Pytorch's nn.TransformerEncoder "src_key_padding_mask" not functioning as expected
我正在使用 Pytorch 的 nn.TransformerEncoder
模块。我得到了具有(正常)形状(batch-size, seq-len, emb-dim
)的输入样本。一批中的所有样本都被零填充到该批次中最大样本的大小。因此,我希望忽略所有零值的注意力。
文档说,向 nn.TransformerEncoder
模块的 forward
函数添加参数 src_key_padding_mask
。这个掩码应该是一个形状为 (batch-size, seq-len
) 的张量,并且每个索引都有 True
用于填充零或 False
用于其他任何东西。
我通过以下方式做到了这一点:
. . .
def forward(self, x):
# x.size -> i.e.: (200, 28, 200)
mask = (x == 0).cuda().reshape(x.shape[0], x.shape[1])
# mask.size -> i.e.: (200, 20)
x = self.embed(x.type(torch.LongTensor).to(device=device))
x = self.pe(x)
x = self.transformer_encoder(x, src_key_padding_mask=mask)
. . .
当我不设置 src_key_padding_mask
时一切正常。但是我得到的错误如下:
File "/home/me/.conda/envs/py37/lib/python3.7/site-packages/torch/nn/functional.py", line 4282, in multi_head_attention_forward
assert key_padding_mask.size(0) == bsz
AssertionError
似乎是在比较掩码的第一个维度,即批量大小,bsz
可能代表批量大小。但是为什么会失败呢?非常感谢帮助!
我遇到了同样的问题,这不是错误:pytorch's Transformer implementation 要求输入 x
为 (seq-len x batch-size x emb-dim)
而你的似乎是 (batch-size x seq-len x emb-dim)
。
我正在使用 Pytorch 的 nn.TransformerEncoder
模块。我得到了具有(正常)形状(batch-size, seq-len, emb-dim
)的输入样本。一批中的所有样本都被零填充到该批次中最大样本的大小。因此,我希望忽略所有零值的注意力。
文档说,向 nn.TransformerEncoder
模块的 forward
函数添加参数 src_key_padding_mask
。这个掩码应该是一个形状为 (batch-size, seq-len
) 的张量,并且每个索引都有 True
用于填充零或 False
用于其他任何东西。
我通过以下方式做到了这一点:
. . .
def forward(self, x):
# x.size -> i.e.: (200, 28, 200)
mask = (x == 0).cuda().reshape(x.shape[0], x.shape[1])
# mask.size -> i.e.: (200, 20)
x = self.embed(x.type(torch.LongTensor).to(device=device))
x = self.pe(x)
x = self.transformer_encoder(x, src_key_padding_mask=mask)
. . .
当我不设置 src_key_padding_mask
时一切正常。但是我得到的错误如下:
File "/home/me/.conda/envs/py37/lib/python3.7/site-packages/torch/nn/functional.py", line 4282, in multi_head_attention_forward
assert key_padding_mask.size(0) == bsz
AssertionError
似乎是在比较掩码的第一个维度,即批量大小,bsz
可能代表批量大小。但是为什么会失败呢?非常感谢帮助!
我遇到了同样的问题,这不是错误:pytorch's Transformer implementation 要求输入 x
为 (seq-len x batch-size x emb-dim)
而你的似乎是 (batch-size x seq-len x emb-dim)
。