MultiheadAttention 的可学习参数数量

Number of learnable parameters of MultiheadAttention

在测试时(使用 PyTorch 的 MultiheadAttention),我注意到增加或减少多头注意力的头数不会改变我模型的可学习参数总数。

这种行为是否正确?如果是,为什么?

heads 的数量不应该影响模型可以学习的参数数量吗?

多头注意力的标准实现是将模型的维度除以注意力头的数量。

具有单个注意力头的维度 d 模型会将嵌入投影到 d 维查询、键和值张量(每个投影计数 d2 参数,不包括偏差,总共 3d2).

具有 k 个注意力头的相同维度的模型会将嵌入投影到 k[=53 的三元组=]维查询,键值张量(每个投影计数d×d/k=d2/k 参数,不包括偏差,总共 3kd2/k=3d2).


参考文献:

来自原文:

您引用的 Pytorch 实现: