Tensorflow,用tf.train.Saver保存了什么?
Tensorflow, what is saved with tf.train.Saver?
我想知道当我在每个训练阶段后使用 tf.train.Saver() 保存我的模型时到底保存了什么。与我习惯使用 Keras 模型的文件相比,该文件似乎有点大。现在我的 RNN 在每次保存时占用 900 MB。有没有办法告诉保存器只保存可训练的参数?我还想要一种只保存部分模型的方法。我知道我可以使用 numpy 格式获取我定义的变量并保存它们,但是当我使用 RNN 类 时,我无法直接访问它们的权重,我查看了代码,没有什么比 get_weights 我能看到。
你可以提供一个变量列表来保存在Saver
构造函数中,即saver=tf.train.Saver(var_list=tf.trainable_variables())
默认保存所有variables._all_saveable_objects()
,如果Saver
不指定var_list
。
即Saver
会默认保存所有全局变量和可保存变量
def _all_saveable_objects():
"""Returns all variables and `SaveableObject`s that must be checkpointed.
Returns:
A list of `Variable` and `SaveableObject` to be checkpointed
"""
# TODO(andreasst): make this function public once things are settled.
return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) +
ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))
我想知道当我在每个训练阶段后使用 tf.train.Saver() 保存我的模型时到底保存了什么。与我习惯使用 Keras 模型的文件相比,该文件似乎有点大。现在我的 RNN 在每次保存时占用 900 MB。有没有办法告诉保存器只保存可训练的参数?我还想要一种只保存部分模型的方法。我知道我可以使用 numpy 格式获取我定义的变量并保存它们,但是当我使用 RNN 类 时,我无法直接访问它们的权重,我查看了代码,没有什么比 get_weights 我能看到。
你可以提供一个变量列表来保存在Saver
构造函数中,即saver=tf.train.Saver(var_list=tf.trainable_variables())
默认保存所有variables._all_saveable_objects()
,如果Saver
不指定var_list
。
即Saver
会默认保存所有全局变量和可保存变量
def _all_saveable_objects():
"""Returns all variables and `SaveableObject`s that must be checkpointed.
Returns:
A list of `Variable` and `SaveableObject` to be checkpointed
"""
# TODO(andreasst): make this function public once things are settled.
return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) +
ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))