如何解码seq2seq的输出?

How to decode the output of seq2seq?

Tensorflow translate.py 示例的代码 here 让我很困惑。复制的代码是:

  # This is a greedy decoder - outputs are just argmaxes of output_logits.
  outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]

为什么 argmax 有效?

output_logits的形状是[bucket_length,batch_size,embedding_size]

对于每个 logit(或:每个单词的激活),他们采用激活值最高的索引。

对于 argmax:查看此页面上的 numpy 示例:https://docs.scipy.org/doc/numpy/reference/generated/numpy.argmax.html

a = array([[0, 1, 2],
       [3, 4, 5]])
>>> np.argmax(a)
5
>>> np.argmax(a, axis=0)
array([1, 1, 1])
>>> np.argmax(a, axis=1)
array([2, 2])

所以输出的结果是:

  • 对于每个单词(bucket_length的长度)
    • 获得embedding_size
    • 的最大激活

您应该查看生成的输出数组的形状。您会看到,因为 batch_size 是 1,所以一切正常!

如果这对你有帮助,请告诉我!