Multi Head Attention:正确实现 Q、K、V 的线性变换

Multi Head Attention: Correct implementation of Linear Transformations of Q, K, V

我现在正在 Pytorch 中实现 Multi-Head Self-Attention。我查看了几个实现,它们似乎有点不对,或者至少我不确定为什么要这样做。他们通常只应用一次线性投影:

    self.query_projection = nn.Linear(input_dim, output_dim)
    self.key_projection = nn.Linear(input_dim, output_dim)
    self.value_projection = nn.Linear(input_dim, output_dim)

然后他们经常将投影重塑为

    query_heads = query_projected.view(batch_size, query_lenght, head_count, head_dimension).transpose(1,2)
    key_heads = key_projected.view(batch_size, key_len, head_count, head_dimension).transpose(1, 2)  # (batch_size, heads_count, key_len, d_head)
    value_heads = value_projected.view(batch_size, value_len, head_count, head_dimension).transpose(1, 2)  # (batch_size, heads_count, value_len, d_head)

    attention_weights = scaled_dot_product(query_heads, key_heads) 

根据此代码,每个负责人将工作 一个预计查询。然而,最初的论文说我们需要为编码器中的每个头设置不同的线性投影。

这个显示的实现是否正确?

它们是等价的。

理论上(以及在论文写作中),将它们视为单独的线性投影更容易。假设你有 8 个头,每个头都有一个 M->N 投影,那么其中一个会有 8 N by M 矩阵。

虽然在实现中,通过 8N by M 矩阵进行 M->8N 转换会更快。

可以连接第一个公式中的矩阵以获得第二个公式中的矩阵。