Saving TensorFlow encoder, decoder and attention

I started training a simple NMT (neural machine translation) model with an encoder and a decoder. The training runs on Colab:

encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)
decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)

Then I save the model using a checkpoint:

# On the local machine the dir was changed to 'training_checkpoints/' to fit the location
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)

and during training I save it with:
checkpoint.save(file_prefix = checkpoint_prefix)

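Restoring the checkpoint after training works fine on Colab, even after saving the whole checkpoint folder to Google Drive and restoring it from there. But when I try to restore it on my local machine, the model returns different, garbage results. I restore the checkpoint before training starts with: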
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

Colab notebook output:

Input: <start> يلعبون الكرة <end>
Predicted translation: he played soccer . <end> 

Local machine output:

Input: <start> يلعبون الكرة <end>
Predicted translation: take either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either

Colab TensorFlow version: 1.13.0-rc1

Local machine TensorFlow version: 1.12.0

I know the problem is caused by the different TensorFlow versions. How can I save the model so that I don't run into this problem?

Extra link for the NMT notebook: Neural Machine Translation with Attention

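TF only makes forward compatibility guarantees: https://www.tensorflow.org/guide/version_compat#compatibility_of_graphs_and_checkpoints So it is not surprising that 1.13 saves a file that 1.12 cannot restore. Could you upgrade TensorFlow on the local machine?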
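
If upgrading is not immediately possible, one way to at least catch the mismatch early is to record the TensorFlow version next to the checkpoint when saving and compare it before restoring. The following is only a minimal sketch using the standard library. The file name tf_version.json and the helper names are made up for illustration and are not part of TensorFlow, and comparing only (major, minor) is an assumption about what counts as "newer". In real use the version string would come from tf.__version__.

```python
import json
import os
import re
import tempfile


def parse_version(version):
    """Return (major, minor) from a string such as '1.13.0-rc1'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError("unrecognised version string: %r" % version)
    return int(match.group(1)), int(match.group(2))


def record_version(checkpoint_dir, tf_version):
    """Write the saving TensorFlow version next to the checkpoint files."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Hypothetical file name, chosen only for this sketch.
    path = os.path.join(checkpoint_dir, "tf_version.json")
    with open(path, "w") as handle:
        json.dump({"tf_version": tf_version}, handle)
    return path


def saved_by_newer_version(checkpoint_dir, local_version):
    """True if the checkpoint was written by a newer (major, minor) release."""
    path = os.path.join(checkpoint_dir, "tf_version.json")
    with open(path) as handle:
        saved_version = json.load(handle)["tf_version"]
    return parse_version(saved_version) > parse_version(local_version)


# Demo in a temporary directory, using the two versions from the question.
demo_dir = tempfile.mkdtemp()
record_version(demo_dir, "1.13.0-rc1")  # on Colab this would be tf.__version__
print(saved_by_newer_version(demo_dir, "1.12.0"))  # local machine from the question
print(saved_by_newer_version(demo_dir, "1.13.1"))  # a local machine after upgrading
```

On the training side, record_version(checkpoint_dir, tf.__version__) would be called next to checkpoint.save(...). On the restoring side, a True result from saved_by_newer_version(checkpoint_dir, tf.__version__) is a hint to upgrade before calling checkpoint.restore(...).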