Transformer编码器中的Query padding mask和key padding mask

Query padding mask and key padding mask in Transformer encoder

我正在使用 pytorch 在 transformer 编码器中实现自注意力部分 nn.MultiheadAttention 并在 transformer 的填充掩码中混淆。

下图为query(行)和key(列)的self-attention权重

如您所见,有一些标记“”,我已经在密钥中屏蔽了它。因此代币不会计算注意力权重。

还有两个问题:

  1. 在查询部分,除了红色方块部分,我是否也可以屏蔽它们(“”)?这合理吗?

  2. 如何在查询中屏蔽“”?

通过在 src_masksrc_key_padding_mask 参数中给出掩码,注意力权重也沿行使用 softmax 函数。如果我将所有“”行设置为 -infsoftmax 将 return nan 并且损失为 nan

在自我注意期间不需要屏蔽查询,如果不使用网络中对应于 <PAD> 个标记的状态就足够了(作为隐藏状态或 keys/values), 它们不会影响损失函数或网络中的任何其他内容。

如果你想确保你没有犯错误导致梯度流过 <PAD> 标记,你可以在计算后使用 torch.where 显式地将自注意力归零.