如何批量计算指针网络的交叉熵?

How to batch compute cross entropy for pointer networks?

在指针网络中,输出对数超过了输入的长度。使用此类批次意味着将输入填充到批次输入的最大长度。现在,这一切都很好,直到我们必须计算损失。目前我正在做的是:

logits = stabilize(logits(inputs))     #[batch, max_length]. subtract max(logits) to stabilize
masks = masks(inputs)     #[batch, max_length]. 1 for actual inputs, 0 for padded locations
exp_logits = exp(logits)
exp_logits_masked = exp_logits*masks
probs = exp_logits_masked/sum(exp_logits_masked)

现在我使用这些概率来计算交叉熵

cross_entropy = sum_over_batches(probs[correct_class])

我能做得更好吗?关于处理指针网络的人通常如何完成它的任何想法?

如果我没有可变大小的输入,这一切都可以在 logits 和标签上使用 callable tf.nn.softmax_cross_entropy_with_logits 来实现(这是高度优化的)但是可变长度会产生错误的结果,因为 softmax 计算的分母更大输入中的每个填充 1。

您的方法看起来很准确,据我所知,这也是 RNN 单元中的实现方式。注意 1x 的导数 = dx,0x 的导数 = 0。这会产生你想要的结果,因为你是 summing/averaging 网络末端的梯度。

您唯一可能考虑的是根据屏蔽值的数量重新调整损失。您可能会注意到,当有 0 个屏蔽值时,您的渐变幅度将与许多屏蔽值的幅度略有不同。我不清楚这是否会产生重大影响,但也许会产生非常小的影响。

否则,我自己也使用同样的技术取得了巨大的成功,所以我在这里说你走在正确的轨道上。