为什么 torch.nn.MultiheadAttention 中的 W_q 矩阵是二次矩阵

Why W_q matrix in torch.nn.MultiheadAttention is quadratic

我正尝试在我的网络中实施 nn.MultiheadAttention。根据 docs,

embed_dim  – total dimension of the model.

但是,根据 source file

embed_dim must be divisible by num_heads

self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))

如果我理解正确的话,这意味着每个头只每个查询的一部分特征,因为矩阵是二次的。是实现的bug还是我的理解有误?

每个头使用投影查询向量的不同部分。你可以想象它好像查询被分成 num_heads 个向量,这些向量独立地用于计算缩放的 dot-product 注意力。因此,每个头对查询中的特征(以及键和值)的不同线性组合进行操作。此线性投影是使用 self.q_proj_weight 矩阵完成的,投影查询将传递给 F.multi_head_attention_forward 函数。

F.multi_head_attention_forward, it is implemented by reshaping and transposing the query vector, so that the independent attentions for individual heads can be computed efficiently by matrix multiplication.

注意力头大小是 PyTorch 的设计决策。理论上,您可以有不同的头部尺寸,因此投影矩阵的形状为 embedding_dim × num_heads * head_dims。转换器的某些实现(例如基于 C++ 的 Marian for machine translation, or Huggingface's Transformers)允许这样做。