Transformer编码器中的Query padding mask和key padding mask
Query padding mask and key padding mask in Transformer encoder
我正在使用 pytorch 在 transformer 编码器中实现自注意力部分 nn.MultiheadAttention
并在 transformer 的填充掩码中混淆。
下图为query(行)和key(列)的self-attention权重
如您所见,有一些标记“”,我已经在密钥中屏蔽了它。因此代币不会计算注意力权重。
还有两个问题:
在查询部分,除了红色方块部分,我是否也可以屏蔽它们(“”)?这合理吗?
如何在查询中屏蔽“”?
通过在 src_mask
或 src_key_padding_mask
参数中给出掩码,注意力权重也沿行使用 softmax
函数。如果我将所有“”行设置为 -inf
,softmax
将 return nan
并且损失为 nan
在自我注意期间不需要屏蔽查询,如果不使用网络中对应于 <PAD>
个标记的状态就足够了(作为隐藏状态或 keys/values), 它们不会影响损失函数或网络中的任何其他内容。
如果你想确保你没有犯错误导致梯度流过 <PAD>
标记,你可以在计算后使用 torch.where
显式地将自注意力归零.
我正在使用 pytorch 在 transformer 编码器中实现自注意力部分 nn.MultiheadAttention
并在 transformer 的填充掩码中混淆。
下图为query(行)和key(列)的self-attention权重
如您所见,有一些标记“
还有两个问题:
在查询部分,除了红色方块部分,我是否也可以屏蔽它们(“
”)?这合理吗? 如何在查询中屏蔽“
”?
通过在 src_mask
或 src_key_padding_mask
参数中给出掩码,注意力权重也沿行使用 softmax
函数。如果我将所有“-inf
,softmax
将 return nan
并且损失为 nan
在自我注意期间不需要屏蔽查询,如果不使用网络中对应于 <PAD>
个标记的状态就足够了(作为隐藏状态或 keys/values), 它们不会影响损失函数或网络中的任何其他内容。
如果你想确保你没有犯错误导致梯度流过 <PAD>
标记,你可以在计算后使用 torch.where
显式地将自注意力归零.