如何为空白令牌预测计算变压器损失?

How is transformers loss calculated for blank token predictions?

我目前正在尝试实现一个转换器,但无法理解其损耗计算。

我的编码器输入查找 batch_size=1 和 max_sentence_length=8,例如:

[[Das, Wetter, ist, gut, <blank>, <blank>, <blank>, <blank>]]

我的解码器输入看起来像(德语到英语):

[[<start>, The, weather, is, good, <end>, <blank>, <blank>]]

假设我的转换器预测了那些 class 概率(仅显示具有最高 class 概率的 class 的词):

[[The, good, is, weather, <end>, <blank>, <blank>, <blank>]]

现在我使用以下方法计算损失:

loss = categorical_crossentropy(
   [[The, good, is, weather, <end>, <blank>, <blank>, <blank>]],
   [[The, weather, is, good, <end>, <blank>, <blank>, <blank>]]
)

这是计算损失的正确方法吗?我的转换器总是预测下一个单词的空白标记,我认为那是因为我的损失计算有误,必须在计算损失之前对空白标记做一些事情。

您需要屏蔽填充。 (你所谓的<blank>更常被称为<pad>。)

  • 创建一个掩码说明有效标记的位置(伪代码:mask = target != '<pad>')

  • 计算分类交叉熵时,不自动减少损失,保持原值

  • 将损失值与mask相乘,即<blank>个token对应的位置取零,并在有效位置求和损失。 (伪代码:loss_sum = (loss * mask).sum()

  • loss_sum除以有效位置的个数,即掩码之和(伪码:loss = loss_sum / mask.sum()