MultiHeadAttention 中 att_mask 和 key_padding_mask 有什么区别
what the difference between att_mask and key_padding_mask in MultiHeadAttnetion
pytorch的MultiHeadAttnetion
中att_mask
和key_padding_mask
有什么区别:
key_padding_mask – if provided, specified padding elements in the key will be ignored by the attention. When given a binary mask and a value is True, the corresponding value on the attention layer will be ignored. When given a byte mask and a value is non-zero, the corresponding value on the attention layer will be ignored
attn_mask – 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch.
提前致谢。
key_padding_mask
用于屏蔽填充的位置,即在输入序列结束之后。这始终特定于输入批次,并且取决于批次中的序列与最长序列相比有多长。它是形状为 批量大小 × 输入长度 .
的二维张量
另一方面,attn_mask
表示哪些键值对有效。在 Transformer 解码器中,三角形掩码用于模拟推理时间并防止关注“未来”位置。这就是 att_mask
通常的用途。如果是2D张量,shape是input length×input length。您还可以拥有一个特定于批次中每个项目的掩码。在这种情况下,您可以使用形状为 (batch size × num heads) × input length × input length 的 3D 张量。 (因此,理论上,您可以用 3D att_mask
模拟 key_padding_mask
。)
我认为它们的工作原理是一样的:两个掩码都定义了查询和键之间的哪些注意力不会被使用。而这两种选择的唯一区别在于你更愿意输入哪种形状的面具
根据代码,这两个mask好像是merged/taken union 所以他们都起着同样的作用——不会用到query和key之间的attention。因为它们是联合的:如果您需要使用两个掩码,则两个掩码输入可以具有不同的值,或者您可以根据需要的形状方便地输入 mask_args 中的掩码:这是的一部分函数 multi_head_attention_forward()
中第 5227 行附近 pytorch/functional.py 的原始代码
...
# merge key padding and attention masks
if key_padding_mask is not None:
assert key_padding_mask.shape == (bsz, src_len), \
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
if attn_mask is None:
attn_mask = key_padding_mask
elif attn_mask.dtype == torch.bool:
attn_mask = attn_mask.logical_or(key_padding_mask)
else:
attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
...
# so here only the merged/unioned mask is used to actually compute the attention
attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
如有不同意见或我说的不对请指正
pytorch的MultiHeadAttnetion
中att_mask
和key_padding_mask
有什么区别:
key_padding_mask – if provided, specified padding elements in the key will be ignored by the attention. When given a binary mask and a value is True, the corresponding value on the attention layer will be ignored. When given a byte mask and a value is non-zero, the corresponding value on the attention layer will be ignored
attn_mask – 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch.
提前致谢。
key_padding_mask
用于屏蔽填充的位置,即在输入序列结束之后。这始终特定于输入批次,并且取决于批次中的序列与最长序列相比有多长。它是形状为 批量大小 × 输入长度 .
另一方面,attn_mask
表示哪些键值对有效。在 Transformer 解码器中,三角形掩码用于模拟推理时间并防止关注“未来”位置。这就是 att_mask
通常的用途。如果是2D张量,shape是input length×input length。您还可以拥有一个特定于批次中每个项目的掩码。在这种情况下,您可以使用形状为 (batch size × num heads) × input length × input length 的 3D 张量。 (因此,理论上,您可以用 3D att_mask
模拟 key_padding_mask
。)
我认为它们的工作原理是一样的:两个掩码都定义了查询和键之间的哪些注意力不会被使用。而这两种选择的唯一区别在于你更愿意输入哪种形状的面具
根据代码,这两个mask好像是merged/taken union 所以他们都起着同样的作用——不会用到query和key之间的attention。因为它们是联合的:如果您需要使用两个掩码,则两个掩码输入可以具有不同的值,或者您可以根据需要的形状方便地输入 mask_args 中的掩码:这是的一部分函数 multi_head_attention_forward()
...
# merge key padding and attention masks
if key_padding_mask is not None:
assert key_padding_mask.shape == (bsz, src_len), \
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
if attn_mask is None:
attn_mask = key_padding_mask
elif attn_mask.dtype == torch.bool:
attn_mask = attn_mask.logical_or(key_padding_mask)
else:
attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
...
# so here only the merged/unioned mask is used to actually compute the attention
attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
如有不同意见或我说的不对请指正