How to save a seq2seq model in TensorFlow 2.x?
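I am following the "Neural machine translation with attention" tutorial in the TensorFlow docs, but I can't figure out how to save the model as a SavedModel.
As the docs show, I can save checkpoints fairly easily, but as far as I know that is not very useful for integrating with other applications. Does anyone know how to save the whole "model", even though the tutorial does not use tf.keras.Model?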
Docs: https://www.tensorflow.org/tutorials/text/nmt_with_attention
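As described in the TensorFlow documentation, TensorFlow has two saving mechanisms: Checkpoints and SavedModel.
If the code (your training code, or the tutorial code in this case) will always be available, you can simply restore the model from a checkpoint and use it.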
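As a rough illustration of the checkpoint route, here is a minimal sketch. TinyModel is a hypothetical stand-in for the tutorial's encoder and decoder (it is not part of the tutorial), and the temporary directory is only for demonstration. The point is that restoring needs the Python class to be available so the objects can be rebuilt first.

```python
import os
import tempfile

import tensorflow as tf


# Hypothetical stand-in for the tutorial's encoder/decoder objects.
class TinyModel(tf.Module):
    def __init__(self):
        super().__init__()
        self.w = tf.Variable([[1.0], [2.0]], name="w")


model = TinyModel()
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "ckpt")

# A checkpoint stores variable values only, not the computation.
checkpoint = tf.train.Checkpoint(model=model)
save_path = checkpoint.save(ckpt_prefix)

# To restore, the same Python code must rebuild the object first.
restored_model = TinyModel()
restored_model.w.assign([[0.0], [0.0]])  # overwrite so the restore is visible
tf.train.Checkpoint(model=restored_model).restore(save_path)

restored_values = restored_model.w.numpy().tolist()
```

This works well while the model code lives next to the checkpoint, which is why it is less convenient for handing the model to another application.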
To get a SavedModel, you need to rewrite the code as a
class CustomModule(tf.Module)
and be careful about the following:
When you save a tf.Module, any tf.Variable attributes, tf.function-decorated methods, and tf.Modules found via recursive traversal are saved. (See the Checkpoint tutorial for more about this recursive traversal.) However, any Python attributes, functions, and data are lost. This means that when a tf.function is saved, no Python code is saved.
If no Python code is saved, how does SavedModel know how to restore the function?
Briefly, tf.function works by tracing the Python code to generate a ConcreteFunction (a callable wrapper around tf.Graph). When saving a tf.function, you're really saving the tf.function's cache of ConcreteFunctions.
To learn more about the relationship between tf.function and ConcreteFunctions, see the tf.function guide.
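The behaviour quoted above can be sketched as follows. This is a minimal example under assumptions, not the tutorial's model: CustomModule here holds a single variable, the method name apply and the attribute note are hypothetical, and a real seq2seq export would wrap the encoder, decoder, and decoding loop in the tf.function instead.

```python
import tempfile

import tensorflow as tf


class CustomModule(tf.Module):
    def __init__(self):
        super().__init__()
        # tf.Variable attributes are saved with the module.
        self.scale = tf.Variable(3.0)
        # Plain Python attributes like this one are not saved.
        self.note = "python-only"

    # Hypothetical method name. The input_signature lets TensorFlow trace
    # a ConcreteFunction at save time without calling the method first.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def apply(self, x):
        return x * self.scale


module = CustomModule()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir)

# Loading does not need the CustomModule class definition.
loaded = tf.saved_model.load(export_dir)
result = loaded.apply(tf.constant([1.0, 2.0])).numpy().tolist()
```

Only the traced graph of apply and the variable it uses are written to disk, so any Python-side logic that should survive the export has to live inside a tf.function-decorated method.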
More information can be found in the TensorFlow SavedModel guide.