Tensorflow addons seq2seq output of BasicDecoder call (tfa.seq2seq)

I am building a seq2seq model on top of tfa.seq2seq that works basically as in https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt#train_the_model. I am looking at what BasicDecoder actually returns when it is called. I create a decoder instance

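  import tensorflow_addons as tfa
  # greedy_sampler is a tfa.seq2seq.GreedyEmbeddingSampler(); decoder.fc is the output Dense layer
  decoder_instance = tfa.seq2seq.BasicDecoder(cell=decoder.rnn_cell,
                               sampler=greedy_sampler, output_layer=decoder.fc)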

and later call it:

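  # Returns (final_outputs, final_state, final_sequence_lengths); with GreedyEmbeddingSampler
  # the first argument is the embedding matrix used to embed each predicted id
  outputs, _, _ = decoder_instance(decoder_embedding_matrix,
      start_tokens=start_tokens, end_token=end_token, initial_state=decoder_initial_state)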

What is outputs here: predicted probabilities?

Next, I want to do something like this:

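  import tensorflow as tf

  # rnn_output holds the logits, shape (batch, time, vocab_size); keep the last time step
  predicted_logits = outputs.rnn_output[:, -1, :]
  predicted_logits = predicted_logits / temperature

  # Sample the output logits to generate token IDs.
  # tf.random.categorical expects unnormalized log-probabilities, shape (batch, vocab_size)
  predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
  predicted_ids = tf.squeeze(predicted_ids, axis=-1)

  # Convert from token ids to characters
  # (chars_from_ids is assumed to be an inverse lookup layer defined elsewhere,
  #  e.g. tf.keras.layers.StringLookup(..., invert=True))
  predicted_chars = chars_from_ids(predicted_ids)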
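To make the temperature step concrete without TensorFlow, here is a minimal pure-Python sketch for a single row of logits. The helper names softmax and sample_with_temperature and the four logit values are hypothetical, not part of tfa or of the question. Dividing by a temperature below 1 sharpens the distribution and above 1 flattens it; drawing one index from the softmax of the scaled logits is what tf.random.categorical does per row, since it takes unnormalized log-probabilities rather than probabilities.

```python
import math
import random


def softmax(logits):
    """Turn a list of unnormalized logits into probabilities (numerically stable)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_with_temperature(logits, temperature=1.0, rng=random):
    """Scale logits by 1/temperature, then draw one index from the softmax distribution.

    Hypothetical helper: a plain-Python analogue of
    tf.random.categorical(logits / temperature, num_samples=1) for one row.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    probs = softmax([x / temperature for x in logits])
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


# Hypothetical logits for a 4-token vocabulary (not taken from the question).
logits = [2.0, 1.0, 0.5, -1.0]
rng = random.Random(0)
print(softmax([x / 0.5 for x in logits]))  # low temperature: sharper distribution
print(softmax([x / 2.0 for x in logits]))  # high temperature: flatter distribution
print(sample_with_temperature(logits, temperature=0.5, rng=rng))
```

With temperature 1.0 this reduces to sampling from the plain softmax of the logits; as the temperature approaches 0 it approaches the greedy argmax choice.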

Edit

In my test, outputs looks like this:

BasicDecoderOutput(rnn_output=<tf.Tensor: shape=(1, 1, 106), dtype=float32, numpy=
array([[[-1.7647576 ,  1.2142688 ,  2.3475904 ,  0.35890207,
          0.72230023, -0.3587367 , -0.02984604, -1.9962349 ,
          0.510706  , -1.4457364 , -0.43458703, -0.55248725,
         -0.9126631 , -0.5542034 , -1.2392808 , -1.0972862 ,
         -0.7256295 ,  0.02101   , -1.0858598 ,  0.9452345 ,
          0.56474745,  0.2157154 ,  1.6094822 ,  0.6396736 ,
          1.5741622 ,  1.4455014 ,  0.9529134 ,  0.37970737,
         -0.60284877,  0.73455685,  1.0571934 ,  1.3716137 ,
         -1.0882497 ,  1.7738185 ,  1.1919689 ,  0.8144775 ,
          0.84732264,  1.6677057 ,  1.8040668 ,  0.86257285,
          2.0206916 ,  1.3602887 ,  1.2091455 ,  1.318665  ,
         -0.6775206 , -0.9906771 , -0.39923188, -1.0290842 ,
         -1.3546644 , -1.5678416 ,  0.624691  , -1.0316744 ,
          1.2098004 ,  1.4669724 ,  0.9996722 ,  0.12806134,
         -0.42086226, -0.11248919, -0.8277442 ,  0.622267  ,
         -1.6404072 ,  0.2762841 , -0.54035664, -0.6325757 ,
         -0.16794772,  0.8435169 ,  1.1214966 , -1.5629222 ,
          0.27472585,  0.8861834 , -1.7886144 ,  0.56741697,
         -1.9197755 , -1.8073375 , -1.5050163 , -1.7794812 ,
         -0.11308812,  1.3161705 ,  1.027235  ,  1.3830551 ,
         -1.374056  , -1.4779223 ,  0.19962706, -1.6843308 ,
          0.370475  ,  0.8292502 , -1.2990475 , -1.8491654 ,
         -3.4606798 , -0.9822829 , -2.391135  , -3.6944065 ,
         -3.5912528 , -2.4165688 , -2.640759  , -4.0524964 ,
         -3.0878603 , -1.6555822 , -1.2015637 , -1.7716323 ,
          1.7384199 , -2.4340994 , -0.7337967 , -0.88279086,
         -0.85630864, -0.8148002 ]]], dtype=float32)>, sample_id=<tf.Tensor: shape=(1, 1), dtype=int32, numpy=array([[2]], dtype=int32)>)

For inference, class GreedyEmbeddingSampler(Sampler) is used (https://github.com/tensorflow/addons/blob/v0.15.0/tensorflow_addons/seq2seq/sampler.py#L559-L650). Its sample method is:

  def sample(self, time, outputs, state):
      """sample for GreedyEmbeddingHelper."""
      del time, state  # unused by sample_fn
      # Outputs are logits, use argmax to get the most probable id
      if not isinstance(outputs, tf.Tensor):
          raise TypeError(
              "Expected outputs to be a single Tensor, got: %s" % type(outputs)
          )
      sample_ids = tf.argmax(outputs, axis=-1, output_type=tf.int32)
      return sample_ids

So, as the comment in the source says ("# Outputs are logits, use argmax to get the most probable id"), the outputs are logits, not probabilities.
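A small pure-Python check of why argmax can be applied directly to logits: softmax is strictly increasing in each logit, so it changes the values but not which index is largest. The helper names softmax and greedy_sample are hypothetical; the eight numbers are the first eight entries of rnn_output from the test output above (a slice, not the full 106-entry row).

```python
import math


def softmax(logits):
    """Convert logits to probabilities (numerically stable)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_sample(values):
    """Index of the largest value, a one-row analogue of tf.argmax(outputs, axis=-1)."""
    return max(range(len(values)), key=lambda i: values[i])


# First eight values of rnn_output from the test output in the question.
row = [-1.7647576, 1.2142688, 2.3475904, 0.35890207,
       0.72230023, -0.3587367, -0.02984604, -1.9962349]
probs = softmax(row)
print(greedy_sample(row), greedy_sample(probs))  # → 2 2 (softmax preserves the order)
print(round(sum(probs), 6))  # → 1.0 (probabilities sum to 1; raw logits need not)
```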

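BasicDecoder returns outputs = BasicDecoderOutput(cell_outputs, sample_ids). The rnn_output field is the output of the RNN cell, or of the final dense layer when output_layer is set (here decoder.fc), so it holds unnormalized logits over the vocabulary rather than probabilities (several values in the test output are negative). The sample_id field is the id chosen by the argmax of those logits. This matches the test output above: the largest value in rnn_output, 2.3475904, is at index 2, and sample_id is [[2]]. To get probabilities, apply a softmax (e.g. tf.nn.softmax) to rnn_output.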
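The following is a simplified sketch of BasicDecoder's per-step logic as read from the tfa v0.15.0 source (cell, then optional output_layer, then the sampler), written in plain Python so it runs without TensorFlow. It omits the time argument and the sampler's next_inputs and finished handling. toy_cell, toy_dense, the weight matrix W and argmax are hypothetical stand-ins; only the BasicDecoderOutput field names come from the output shown above.

```python
from collections import namedtuple

# Same field names as the BasicDecoderOutput printed in the question.
BasicDecoderOutput = namedtuple("BasicDecoderOutput", ["rnn_output", "sample_id"])


def decoder_step(cell, output_layer, sample_fn, inputs, state):
    """One simplified decoding step: cell -> optional dense layer -> sampler."""
    cell_outputs, next_state = cell(inputs, state)
    if output_layer is not None:
        cell_outputs = output_layer(cell_outputs)  # project to vocabulary-size logits
    sample_ids = sample_fn(cell_outputs)  # e.g. greedy argmax over the logits
    return BasicDecoderOutput(cell_outputs, sample_ids), next_state


# Hypothetical stand-ins: an identity "cell" and a dense layer with fixed weights.
def toy_cell(inputs, state):
    return inputs, state


W = [[1.0, -1.0, 0.5],
     [0.0, 2.0, -0.5]]  # 2 hidden units -> 3 "vocabulary" logits


def toy_dense(h):
    return [sum(h[i] * W[i][j] for i in range(len(h))) for j in range(len(W[0]))]


def argmax(v):
    return max(range(len(v)), key=lambda i: v[i])


out, _ = decoder_step(toy_cell, toy_dense, argmax, [1.0, 0.5], state=None)
print(out)  # → BasicDecoderOutput(rnn_output=[1.0, 0.0, 0.25], sample_id=0)
```

Because the sampler only sees what the dense layer produced, rnn_output carries whatever decoder.fc emits; with no softmax activation on that layer, those are logits.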