TensorFlow Addons seq2seq: output of BasicDecoder call (tfa.seq2seq)
I am building a seq2seq based on tfa.seq2seq, working basically as in https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt#train_the_model. I am looking at the nature of the outputs of the BasicDecoder call. I created a decoder instance
decoder_instance = tfa.seq2seq.BasicDecoder(
    cell=decoder.rnn_cell, sampler=greedy_sampler, output_layer=decoder.fc)
and later call it as
outputs, _, _ = decoder_instance(
    decoder_embedding_matrix, start_tokens=start_tokens,
    end_token=end_token, initial_state=decoder_initial_state)
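For reference, the objects these calls assume (decoder.rnn_cell, greedy_sampler, decoder.fc) might look roughly like the minimal sketch below; the sizes are assumed, not taken from the tutorial, and decoder_embedding_matrix would be the decoder's embedding weight matrix handed to the sampler.

import tensorflow as tf
import tensorflow_addons as tfa

vocab_size = 106   # assumed: matches the last dimension of the output shown below
units = 1024       # assumed decoder cell size

rnn_cell = tf.keras.layers.GRUCell(units)               # decoder RNN cell
fc = tf.keras.layers.Dense(vocab_size)                  # projects cell output to vocab logits
greedy_sampler = tfa.seq2seq.GreedyEmbeddingSampler()   # argmax sampling at inference time

decoder_instance = tfa.seq2seq.BasicDecoder(
    cell=rnn_cell, sampler=greedy_sampler, output_layer=fc)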
What exactly is outputs here: the prediction probabilities?
Next I would like to do something like this
predicted_logits = predicted_logits[:, -1, :]
predicted_logits = predicted_logits/temperature
# Sample the output logits to generate token IDs.
predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
predicted_ids = tf.squeeze(predicted_ids, axis=-1)
# Convert from token ids to characters
predicted_chars = chars_from_ids(predicted_ids)
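Since outputs turns out to be a BasicDecoderOutput namedtuple (see the EDIT below), the logits needed for this kind of temperature sampling would come from its rnn_output field. A hedged sketch, reusing temperature and chars_from_ids from the snippet above:

predicted_logits = outputs.rnn_output           # [batch, time, vocab] logits, not probabilities
predicted_logits = predicted_logits[:, -1, :]   # keep only the last decoding step
predicted_logits = predicted_logits / temperature
predicted_ids = tf.squeeze(
    tf.random.categorical(predicted_logits, num_samples=1), axis=-1)
predicted_chars = chars_from_ids(predicted_ids)  # assumes the same id-to-char lookup as above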
EDIT
In my tests outputs looks like this
BasicDecoderOutput(rnn_output=<tf.Tensor: shape=(1, 1, 106), dtype=float32, numpy=
array([[[-1.7647576 , 1.2142688 , 2.3475904 , 0.35890207,
0.72230023, -0.3587367 , -0.02984604, -1.9962349 ,
0.510706 , -1.4457364 , -0.43458703, -0.55248725,
-0.9126631 , -0.5542034 , -1.2392808 , -1.0972862 ,
-0.7256295 , 0.02101 , -1.0858598 , 0.9452345 ,
0.56474745, 0.2157154 , 1.6094822 , 0.6396736 ,
1.5741622 , 1.4455014 , 0.9529134 , 0.37970737,
-0.60284877, 0.73455685, 1.0571934 , 1.3716137 ,
-1.0882497 , 1.7738185 , 1.1919689 , 0.8144775 ,
0.84732264, 1.6677057 , 1.8040668 , 0.86257285,
2.0206916 , 1.3602887 , 1.2091455 , 1.318665 ,
-0.6775206 , -0.9906771 , -0.39923188, -1.0290842 ,
-1.3546644 , -1.5678416 , 0.624691 , -1.0316744 ,
1.2098004 , 1.4669724 , 0.9996722 , 0.12806134,
-0.42086226, -0.11248919, -0.8277442 , 0.622267 ,
-1.6404072 , 0.2762841 , -0.54035664, -0.6325757 ,
-0.16794772, 0.8435169 , 1.1214966 , -1.5629222 ,
0.27472585, 0.8861834 , -1.7886144 , 0.56741697,
-1.9197755 , -1.8073375 , -1.5050163 , -1.7794812 ,
-0.11308812, 1.3161705 , 1.027235 , 1.3830551 ,
-1.374056 , -1.4779223 , 0.19962706, -1.6843308 ,
0.370475 , 0.8292502 , -1.2990475 , -1.8491654 ,
-3.4606798 , -0.9822829 , -2.391135 , -3.6944065 ,
-3.5912528 , -2.4165688 , -2.640759 , -4.0524964 ,
-3.0878603 , -1.6555822 , -1.2015637 , -1.7716323 ,
1.7384199 , -2.4340994 , -0.7337967 , -0.88279086,
-0.85630864, -0.8148002 ]]], dtype=float32)>, sample_id=<tf.Tensor: shape=(1, 1), dtype=int32, numpy=array([[2]], dtype=int32)>)
Inference uses class GreedyEmbeddingSampler(Sampler): https://github.com/tensorflow/addons/blob/v0.15.0/tensorflow_addons/seq2seq/sampler.py#L559-L650
    def sample(self, time, outputs, state):
        """sample for GreedyEmbeddingHelper."""
        del time, state  # unused by sample_fn
        # Outputs are logits, use argmax to get the most probable id
        if not isinstance(outputs, tf.Tensor):
            raise TypeError(
                "Expected outputs to be a single Tensor, got: %s" % type(outputs)
            )
        sample_ids = tf.argmax(outputs, axis=-1, output_type=tf.int32)
        return sample_ids
So, as the comment says (# Outputs are logits, use argmax to get the most probable id), BasicDecoder returns outputs = BasicDecoderOutput(cell_outputs, sample_ids): the outputs of the RNN cell (or of the final dense output_layer, if one is set) and the ids obtained as the argmax of those logits.
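In other words, outputs.rnn_output holds raw logits (apply softmax if you want probabilities) and outputs.sample_id is their argmax. A small hedged check, using the outputs variable from the call above:

logits = outputs.rnn_output                         # raw logits from the dense output layer
probs = tf.nn.softmax(logits, axis=-1)              # prediction probabilities, if needed
greedy_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)
# greedy_ids should match outputs.sample_id, e.g. [[2]] in the dump above.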