可变长度序列上的 RNN 注意力权重是否应该重新归一化为 "mask" 零填充的影响?

Should RNN attention weights over variable length sequences be re-normalized to "mask" the effects of zero-padding?

明确地说,我指的是 Hierarchical Attention Networks for Document Classification and implemented many places, for example: here 中描述的类型 "self-attention"。我 不是 指的是编码器-解码器模型(即 Bahdanau)中使用的 seq2seq 注意力类型,尽管我的问题也可能适用于此......我只是不太熟悉它。

Self-attention 基本上只是计算 RNN 隐藏状态的加权平均值(均值池的推广,即未加权平均值)。当同一批次中有可变长度序列时,它们通常会被零填充到批次中最长序列的长度(如果使用动态 RNN)。当为每个序列计算注意力权重时,最后一步是 softmax,因此注意力权重总和为 1.

但是,在我见过的每个注意力实现中,都没有注意屏蔽或以其他方式取消零填充对注意力权重的影响。这对我来说似乎是错误的,但我担心我可能遗漏了一些东西,因为似乎没有其他人对此感到困扰。

例如,考虑一个长度为 2 的序列,用零填充到长度 5。最终这导致注意力权重被计算为类似的 0 填充向量的 softmax,例如:

weights = softmax([0.1, 0.2, 0, 0, 0]) = [0.20, 0.23, 0.19, 0.19, 0.19]

并且因为 exp(0)=1,零填充实际上 "waters down" 注意力权重。这可以很容易地解决,在 softmax 操作之后,通过将权重与二进制掩码相乘,即

mask = [1, 1, 0, 0, 0]

然后将权重重新归一化为 1。这将导致:

weights = [0.48, 0.52, 0, 0, 0]

当我这样做时,我几乎 总是 看到性能提升(在我的模型的准确性方面 - 我正在做文档 classification/regression)。那么为什么没有人这样做呢?

有一段时间我认为可能重要的是注意力权重的相对值(即比率),因为梯度不会通过零-无论如何填充。但是,如果归一化无关紧要,那么我们为什么要使用 softmax 而不是仅仅使用 exp(.)? (另外,这并不能解释性能提升...)

好问题!我相信您的担忧是有效的,填充编码器输出的注意力分数为零 确实会影响 注意力。但是,您必须牢记几个方面:

  • 有不同的评分函数,tf-rnn-attention中的函数使用简单的线性+tanh+线性变换。但即使是这个评分函数也可以学会输出负分。如果您查看代码并想象 inputs 由零组成,向量 v 由于偏差不一定为零,并且与 u_omega 的点积可以进一步将其提升到较低的负数(在换句话说,具有非线性的普通简单 NN 可以做出正预测和负预测)。低负分不会冲淡 softmax 中的高分。

  • 由于分桶技术,一个桶内的序列通常具有大致相同的长度,因此不太可能有一半的输入序列用零填充.当然,它并没有解决任何问题,它只是意味着在实际应用中,填充的负面影响自然是有限的。

  • 你最后提到了,但我也想强调一下:最终参与输出是编码器输出的加权和,即相对 值实际上很重要。拿你自己的例子来计算这种情况下的加权和:

    • 第一个是0.2 * o1 + 0.23 * o2(其余为零)
    • 第二个是0.48 * o1 + 0.52 * o2(其余也是零)


    是的,第二个向量的大小是原来的两倍,这不是关键问题,因为它随后进入线性层。但是 o2 上的相对关注度仅比使用掩码时高出 7%。

    这意味着即使注意力权重在学习忽略零输出方面做得不好,对输出向量的最终影响仍然足以让解码器考虑正确的输出, 在这种情况下要专注于 o2.

希望这能让您相信重新归一化并不是那么重要,但如果实际应用可能会加快学习速度。

BERT implementation 应用填充掩码来计算注意力分数。 将 0 添加到非填充注意力分数并将 -10000 添加到填充注意力分数。 e^-10000 与其他注意力得分值相比非常小 w.r.t。

attention_score = [0.1, 0.2, 0, 0, 0]
mask = [0, 0, -10000, -10000] # -10000 is a large negative value 
attention_score += mask
weights = softmax(attention_score)