如何解码seq2seq的输出?
How to decode the output of seq2seq?
Tensorflow translate.py 示例的代码 here 让我很困惑。复制的代码是:
# This is a greedy decoder - outputs are just argmaxes of output_logits.
outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
为什么 argmax
有效?
output_logits
的形状是[bucket_length,batch_size,embedding_size]
对于每个 logit(或:每个单词的激活),他们采用激活值最高的索引。
对于 argmax:查看此页面上的 numpy 示例:https://docs.scipy.org/doc/numpy/reference/generated/numpy.argmax.html
a = array([[0, 1, 2],
[3, 4, 5]])
>>> np.argmax(a)
5
>>> np.argmax(a, axis=0)
array([1, 1, 1])
>>> np.argmax(a, axis=1)
array([2, 2])
所以输出的结果是:
- 对于每个单词(bucket_length的长度)
- 获得embedding_size
的最大激活
您应该查看生成的输出数组的形状。您会看到,因为 batch_size 是 1,所以一切正常!
如果这对你有帮助,请告诉我!
Tensorflow translate.py 示例的代码 here 让我很困惑。复制的代码是:
# This is a greedy decoder - outputs are just argmaxes of output_logits.
outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
为什么 argmax
有效?
output_logits
的形状是[bucket_length,batch_size,embedding_size]
对于每个 logit(或:每个单词的激活),他们采用激活值最高的索引。
对于 argmax:查看此页面上的 numpy 示例:https://docs.scipy.org/doc/numpy/reference/generated/numpy.argmax.html
a = array([[0, 1, 2],
[3, 4, 5]])
>>> np.argmax(a)
5
>>> np.argmax(a, axis=0)
array([1, 1, 1])
>>> np.argmax(a, axis=1)
array([2, 2])
所以输出的结果是:
- 对于每个单词(bucket_length的长度)
- 获得embedding_size 的最大激活
您应该查看生成的输出数组的形状。您会看到,因为 batch_size 是 1,所以一切正常!
如果这对你有帮助,请告诉我!