神经机器翻译模型预测偏差一
Neural Machine Translation model predictions are off-by-one
问题总结
在下面的示例中,我的 NMT 模型具有高损失,因为它正确预测了 target_input
而不是 target_output
。
Targetin : 1 3 3 3 3 6 6 6 9 7 7 7 4 4 4 4 4 9 9 10 10 10 3 3 10 10 3 10 3 3 10 10 3 9 9 4 4 4 4 4 3 10 3 3 9 9 3 6 6 6 6 6 6 10 9 9 10 10 4 4 4 4 4 4 4 4 4 4 4 4 9 9 9 9 3 3 3 6 6 6 6 6 9 9 10 3 4 4 4 4 4 4 4 4 4 4 4 4 9 9 10 3 10 9 9 3 4 4 4 4 4 4 4 4 4 10 10 4 4 4 4 4 4 4 4 4 4 9 9 10 3 6 6 6 6 3 3 3 10 3 3 3 4 4 4 4 4 4 4 4 4 4 4 4 4 9 9 3 3 10 6 6 6 6 6 3 9 9 3 3 3 3 3 3 3 10 10 3 9 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 9 3 6 6 6 6 6 6 3 5 3 3 3 3 10 10 10 3 9 9 5 10 3 3 3 3 9 9 9 5 10 10 10 10 10 4 4 4 4 3 10 6 6 6 6 6 6 3 5 10 10 10 10 3 9 9 6 6 6 6 6 6 6 6 6 9 9 9 3 3 3 6 6 6 6 6 6 6 6 3 9 9 9 3 3 6 6 6 3 3 3 3 3 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
Targetout : 3 3 3 3 6 6 6 9 7 7 7 4 4 4 4 4 9 9 10 10 10 3 3 10 10 3 10 3 3 10 10 3 9 9 4 4 4 4 4 3 10 3 3 9 9 3 6 6 6 6 6 6 10 9 9 10 10 4 4 4 4 4 4 4 4 4 4 4 4 9 9 9 9 3 3 3 6 6 6 6 6 9 9 10 3 4 4 4 4 4 4 4 4 4 4 4 4 9 9 10 3 10 9 9 3 4 4 4 4 4 4 4 4 4 10 10 4 4 4 4 4 4 4 4 4 4 9 9 10 3 6 6 6 6 3 3 3 10 3 3 3 4 4 4 4 4 4 4 4 4 4 4 4 4 9 9 3 3 10 6 6 6 6 6 3 9 9 3 3 3 3 3 3 3 10 10 3 9 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 9 3 6 6 6 6 6 6 3 5 3 3 3 3 10 10 10 3 9 9 5 10 3 3 3 3 9 9 9 5 10 10 10 10 10 4 4 4 4 3 10 6 6 6 6 6 6 3 5 10 10 10 10 3 9 9 6 6 6 6 6 6 6 6 6 9 9 9 3 3 3 6 6 6 6 6 6 6 6 3 9 9 9 3 3 6 6 6 3 3 3 3 3 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
Prediction : 3 3 3 3 3 6 6 6 9 7 7 7 4 4 4 4 4 9 3 3 3 3 3 3 10 3 3 10 3 3 10 3 3 9 3 4 4 4 4 4 3 10 3 3 9 3 3 6 6 6 6 6 6 10 9 3 3 3 4 4 4 4 4 4 4 4 4 4 4 4 9 3 3 3 3 3 3 6 6 6 6 6 9 6 3 3 4 4 4 4 4 4 4 4 4 4 4 4 9 3 3 3 10 9 3 3 4 4 4 4 4 4 4 4 4 3 10 4 4 4 4 4 4 4 4 4 4 9 3 3 3 6 6 6 6 3 3 3 10 3 3 3 4 4 4 4 4 4 4 4 4 4 4 4 4 9 3 3 3 10 6 6 6 6 6 3 9 3 3 3 3 3 3 3 3 3 3 3 9 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 9 3 6 6 6 6 6 6 3 3 3 3 3 3 10 3 3 3 9 3 3 10 3 3 3 3 9 3 9 3 10 3 3 3 3 4 4 4 4 3 10 6 6 6 6 6 6 3 3 10 3 3 3 3 9 3 6 6 6 6 6 6 6 6 6 9 6 9 3 3 3 6 6 6 6 6 6 6 6 3 9 3 9 3 3 6 6 6 3 3 3 3 3 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6
Source : 9 16 4 7 22 22 19 1 12 19 12 18 5 18 9 18 5 8 12 19 19 5 5 19 22 7 12 12 6 19 7 3 20 7 9 14 4 11 20 12 7 1 18 7 7 5 22 9 13 22 20 19 7 19 7 13 7 11 19 20 6 22 18 17 17 1 12 17 23 7 20 1 13 7 11 11 22 7 12 1 13 12 5 5 19 22 5 5 20 1 5 4 12 9 7 12 8 14 18 22 18 12 18 17 19 4 19 12 11 18 5 9 9 5 14 7 11 6 4 17 23 6 4 5 12 6 7 14 4 20 6 8 12 25 4 19 6 1 5 1 5 20 4 18 12 12 1 11 12 1 25 13 18 19 7 12 7 3 4 22 9 9 12 4 8 9 19 9 22 22 19 1 19 7 5 19 4 5 18 11 13 9 4 14 12 13 20 11 12 11 7 6 1 11 19 20 7 22 22 12 22 22 9 3 8 12 11 14 16 4 11 7 11 1 8 5 5 7 18 16 22 19 9 20 4 12 18 7 19 7 1 12 18 17 12 19 4 20 9 9 1 12 5 18 14 17 17 7 4 13 16 14 12 22 12 22 18 9 12 11 3 18 6 20 7 4 20 7 9 1 7 25 13 5 25 14 11 5 20 7 23 12 5 16 19 19 25 19 7 -1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
很明显,预测与 target_input
几乎 100% 匹配,而不是 target_output
,这是应该的(差一)。正在使用 target_output
计算损失和梯度,因此奇怪的是预测匹配到 target_input
。
模型概览
NMT 模型使用源语言中的主要单词序列预测目标语言中的单词序列。这是 Google 翻译背后的框架。由于 NMT 使用耦合 RNN,因此它受到监督并需要标记目标输入和输出。
NMT 使用source
序列、target_input
序列和target_output
序列。在下面的例子中,编码器 RNN(蓝色)使用源输入词产生一个意义向量,它传递给解码器 RNN(红色),解码器使用意义向量产生输出。
在进行新的预测(推理)时,解码器 RNN 使用其自己的先前输出为时间步长中的下一个预测播种。然而,为了改进训练,允许在每个新的时间步使用正确的先前预测为自己播种。这就是为什么 target_input
是训练所必需的。
获取带有源的迭代器的代码,target_in,target_out
def get_batched_iterator(hparams, src_loc, tgt_loc):
if not (os.path.exists('primary.csv') and os.path.exists('secondary.csv')):
utils.integerize_raw_data()
source_dataset = tf.data.TextLineDataset(src_loc)
target_dataset = tf.data.TextLineDataset(tgt_loc)
dataset = tf.data.Dataset.zip((source_dataset, target_dataset))
dataset = dataset.shuffle(hparams.shuffle_buffer_size, seed=hparams.shuffle_seed)
dataset = dataset.map(lambda source, target: (tf.string_to_number(tf.string_split([source], delimiter=',').values, tf.int32),
tf.string_to_number(tf.string_split([target], delimiter=',').values, tf.int32)))
dataset = dataset.map(lambda source, target: (source, tf.concat(([hparams.sos], target), axis=0), tf.concat((target, [hparams.eos]), axis=0)))
dataset = dataset.map(lambda source, target_in, target_out: (source, target_in, target_out, tf.size(source), tf.size(target_in)))
# Proceed to batch and return iterator
NMT模型核心代码
def __init__(self, hparams, iterator, mode):
source, target_in, target_out, source_lengths, target_lengths = iterator.get_next()
# Lookup embeddings
embedding_encoder = tf.get_variable("embedding_encoder", [hparams.src_vsize, hparams.src_emsize])
encoder_emb_inp = tf.nn.embedding_lookup(embedding_encoder, source)
embedding_decoder = tf.get_variable("embedding_decoder", [hparams.tgt_vsize, hparams.tgt_emsize])
decoder_emb_inp = tf.nn.embedding_lookup(embedding_decoder, target_in)
# Build and run Encoder LSTM
encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)
encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell, encoder_emb_inp, sequence_length=source_lengths, dtype=tf.float32)
# Build and run Decoder LSTM with TrainingHelper and output projection layer
decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)
projection_layer = layers_core.Dense(hparams.tgt_vsize, use_bias=False)
helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_inp, sequence_length=target_lengths)
decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, encoder_state, output_layer=projection_layer)
outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
logits = outputs.rnn_output
if mode is 'TRAIN' or mode is 'EVAL': # then calculate loss
crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target_out, logits=logits)
target_weights = tf.sequence_mask(target_lengths, maxlen=tf.shape(target_out)[1], dtype=logits.dtype)
self.loss = tf.reduce_sum((crossent * target_weights) / hparams.batch_size)
if mode is 'TRAIN': # then calculate/clip gradients, then optimize model
params = tf.trainable_variables()
gradients = tf.gradients(self.loss, params)
clipped_gradients, _ = tf.clip_by_global_norm(gradients, hparams.max_gradient_norm)
optimizer = tf.train.AdamOptimizer(hparams.l_rate)
self.update_step = optimizer.apply_gradients(zip(clipped_gradients, params))
if mode is 'EVAL': # then allow access to input/output tensors to printout
self.src = source
self.tgt_in = target_in
self.tgt_out = target_out
self.logits = logits
用于预测具有重复结构的 language-like 语法的 NMT 模型的核心问题是,无论过去的预测是什么,它都会被激励去简单地预测。由于 TrainingHelper
在每一步都为它提供了正确的先前预测以加快训练速度,因此这人为地产生了模型无法摆脱的局部最小值。
我发现的最佳选择是对损失函数进行加权,例如输出序列中输出不重复的关键点的权重更大。这将激励模型获得正确的结果,而不仅仅是重复
过去的预测。
问题总结
在下面的示例中,我的 NMT 模型具有高损失,因为它正确预测了 target_input
而不是 target_output
。
Targetin : 1 3 3 3 3 6 6 6 9 7 7 7 4 4 4 4 4 9 9 10 10 10 3 3 10 10 3 10 3 3 10 10 3 9 9 4 4 4 4 4 3 10 3 3 9 9 3 6 6 6 6 6 6 10 9 9 10 10 4 4 4 4 4 4 4 4 4 4 4 4 9 9 9 9 3 3 3 6 6 6 6 6 9 9 10 3 4 4 4 4 4 4 4 4 4 4 4 4 9 9 10 3 10 9 9 3 4 4 4 4 4 4 4 4 4 10 10 4 4 4 4 4 4 4 4 4 4 9 9 10 3 6 6 6 6 3 3 3 10 3 3 3 4 4 4 4 4 4 4 4 4 4 4 4 4 9 9 3 3 10 6 6 6 6 6 3 9 9 3 3 3 3 3 3 3 10 10 3 9 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 9 3 6 6 6 6 6 6 3 5 3 3 3 3 10 10 10 3 9 9 5 10 3 3 3 3 9 9 9 5 10 10 10 10 10 4 4 4 4 3 10 6 6 6 6 6 6 3 5 10 10 10 10 3 9 9 6 6 6 6 6 6 6 6 6 9 9 9 3 3 3 6 6 6 6 6 6 6 6 3 9 9 9 3 3 6 6 6 3 3 3 3 3 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
Targetout : 3 3 3 3 6 6 6 9 7 7 7 4 4 4 4 4 9 9 10 10 10 3 3 10 10 3 10 3 3 10 10 3 9 9 4 4 4 4 4 3 10 3 3 9 9 3 6 6 6 6 6 6 10 9 9 10 10 4 4 4 4 4 4 4 4 4 4 4 4 9 9 9 9 3 3 3 6 6 6 6 6 9 9 10 3 4 4 4 4 4 4 4 4 4 4 4 4 9 9 10 3 10 9 9 3 4 4 4 4 4 4 4 4 4 10 10 4 4 4 4 4 4 4 4 4 4 9 9 10 3 6 6 6 6 3 3 3 10 3 3 3 4 4 4 4 4 4 4 4 4 4 4 4 4 9 9 3 3 10 6 6 6 6 6 3 9 9 3 3 3 3 3 3 3 10 10 3 9 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 9 3 6 6 6 6 6 6 3 5 3 3 3 3 10 10 10 3 9 9 5 10 3 3 3 3 9 9 9 5 10 10 10 10 10 4 4 4 4 3 10 6 6 6 6 6 6 3 5 10 10 10 10 3 9 9 6 6 6 6 6 6 6 6 6 9 9 9 3 3 3 6 6 6 6 6 6 6 6 3 9 9 9 3 3 6 6 6 3 3 3 3 3 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
Prediction : 3 3 3 3 3 6 6 6 9 7 7 7 4 4 4 4 4 9 3 3 3 3 3 3 10 3 3 10 3 3 10 3 3 9 3 4 4 4 4 4 3 10 3 3 9 3 3 6 6 6 6 6 6 10 9 3 3 3 4 4 4 4 4 4 4 4 4 4 4 4 9 3 3 3 3 3 3 6 6 6 6 6 9 6 3 3 4 4 4 4 4 4 4 4 4 4 4 4 9 3 3 3 10 9 3 3 4 4 4 4 4 4 4 4 4 3 10 4 4 4 4 4 4 4 4 4 4 9 3 3 3 6 6 6 6 3 3 3 10 3 3 3 4 4 4 4 4 4 4 4 4 4 4 4 4 9 3 3 3 10 6 6 6 6 6 3 9 3 3 3 3 3 3 3 3 3 3 3 9 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 9 3 6 6 6 6 6 6 3 3 3 3 3 3 10 3 3 3 9 3 3 10 3 3 3 3 9 3 9 3 10 3 3 3 3 4 4 4 4 3 10 6 6 6 6 6 6 3 3 10 3 3 3 3 9 3 6 6 6 6 6 6 6 6 6 9 6 9 3 3 3 6 6 6 6 6 6 6 6 3 9 3 9 3 3 6 6 6 3 3 3 3 3 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6
Source : 9 16 4 7 22 22 19 1 12 19 12 18 5 18 9 18 5 8 12 19 19 5 5 19 22 7 12 12 6 19 7 3 20 7 9 14 4 11 20 12 7 1 18 7 7 5 22 9 13 22 20 19 7 19 7 13 7 11 19 20 6 22 18 17 17 1 12 17 23 7 20 1 13 7 11 11 22 7 12 1 13 12 5 5 19 22 5 5 20 1 5 4 12 9 7 12 8 14 18 22 18 12 18 17 19 4 19 12 11 18 5 9 9 5 14 7 11 6 4 17 23 6 4 5 12 6 7 14 4 20 6 8 12 25 4 19 6 1 5 1 5 20 4 18 12 12 1 11 12 1 25 13 18 19 7 12 7 3 4 22 9 9 12 4 8 9 19 9 22 22 19 1 19 7 5 19 4 5 18 11 13 9 4 14 12 13 20 11 12 11 7 6 1 11 19 20 7 22 22 12 22 22 9 3 8 12 11 14 16 4 11 7 11 1 8 5 5 7 18 16 22 19 9 20 4 12 18 7 19 7 1 12 18 17 12 19 4 20 9 9 1 12 5 18 14 17 17 7 4 13 16 14 12 22 12 22 18 9 12 11 3 18 6 20 7 4 20 7 9 1 7 25 13 5 25 14 11 5 20 7 23 12 5 16 19 19 25 19 7 -1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
很明显,预测与 target_input
几乎 100% 匹配,而不是 target_output
,这是应该的(差一)。正在使用 target_output
计算损失和梯度,因此奇怪的是预测匹配到 target_input
。
模型概览
NMT 模型使用源语言中的主要单词序列预测目标语言中的单词序列。这是 Google 翻译背后的框架。由于 NMT 使用耦合 RNN,因此它受到监督并需要标记目标输入和输出。
NMT 使用source
序列、target_input
序列和target_output
序列。在下面的例子中,编码器 RNN(蓝色)使用源输入词产生一个意义向量,它传递给解码器 RNN(红色),解码器使用意义向量产生输出。
在进行新的预测(推理)时,解码器 RNN 使用其自己的先前输出为时间步长中的下一个预测播种。然而,为了改进训练,允许在每个新的时间步使用正确的先前预测为自己播种。这就是为什么 target_input
是训练所必需的。
获取带有源的迭代器的代码,target_in,target_out
def get_batched_iterator(hparams, src_loc, tgt_loc):
if not (os.path.exists('primary.csv') and os.path.exists('secondary.csv')):
utils.integerize_raw_data()
source_dataset = tf.data.TextLineDataset(src_loc)
target_dataset = tf.data.TextLineDataset(tgt_loc)
dataset = tf.data.Dataset.zip((source_dataset, target_dataset))
dataset = dataset.shuffle(hparams.shuffle_buffer_size, seed=hparams.shuffle_seed)
dataset = dataset.map(lambda source, target: (tf.string_to_number(tf.string_split([source], delimiter=',').values, tf.int32),
tf.string_to_number(tf.string_split([target], delimiter=',').values, tf.int32)))
dataset = dataset.map(lambda source, target: (source, tf.concat(([hparams.sos], target), axis=0), tf.concat((target, [hparams.eos]), axis=0)))
dataset = dataset.map(lambda source, target_in, target_out: (source, target_in, target_out, tf.size(source), tf.size(target_in)))
# Proceed to batch and return iterator
NMT模型核心代码
def __init__(self, hparams, iterator, mode):
source, target_in, target_out, source_lengths, target_lengths = iterator.get_next()
# Lookup embeddings
embedding_encoder = tf.get_variable("embedding_encoder", [hparams.src_vsize, hparams.src_emsize])
encoder_emb_inp = tf.nn.embedding_lookup(embedding_encoder, source)
embedding_decoder = tf.get_variable("embedding_decoder", [hparams.tgt_vsize, hparams.tgt_emsize])
decoder_emb_inp = tf.nn.embedding_lookup(embedding_decoder, target_in)
# Build and run Encoder LSTM
encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)
encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell, encoder_emb_inp, sequence_length=source_lengths, dtype=tf.float32)
# Build and run Decoder LSTM with TrainingHelper and output projection layer
decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)
projection_layer = layers_core.Dense(hparams.tgt_vsize, use_bias=False)
helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_inp, sequence_length=target_lengths)
decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, encoder_state, output_layer=projection_layer)
outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
logits = outputs.rnn_output
if mode is 'TRAIN' or mode is 'EVAL': # then calculate loss
crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target_out, logits=logits)
target_weights = tf.sequence_mask(target_lengths, maxlen=tf.shape(target_out)[1], dtype=logits.dtype)
self.loss = tf.reduce_sum((crossent * target_weights) / hparams.batch_size)
if mode is 'TRAIN': # then calculate/clip gradients, then optimize model
params = tf.trainable_variables()
gradients = tf.gradients(self.loss, params)
clipped_gradients, _ = tf.clip_by_global_norm(gradients, hparams.max_gradient_norm)
optimizer = tf.train.AdamOptimizer(hparams.l_rate)
self.update_step = optimizer.apply_gradients(zip(clipped_gradients, params))
if mode is 'EVAL': # then allow access to input/output tensors to printout
self.src = source
self.tgt_in = target_in
self.tgt_out = target_out
self.logits = logits
用于预测具有重复结构的 language-like 语法的 NMT 模型的核心问题是,无论过去的预测是什么,它都会被激励去简单地预测。由于 TrainingHelper
在每一步都为它提供了正确的先前预测以加快训练速度,因此这人为地产生了模型无法摆脱的局部最小值。
我发现的最佳选择是对损失函数进行加权,例如输出序列中输出不重复的关键点的权重更大。这将激励模型获得正确的结果,而不仅仅是重复 过去的预测。