如何控制张量流估算器保留的检查点数量?
How to control amount of checkpoint kept by tensorflow estimator?
我注意到新的 Estimator API 会在训练期间自动保存检查点,并在训练中断时自动从最后一个检查点重新开始。不幸的是,它似乎只保留了最后 5 个检查点。
你知道如何控制训练过程中保留的检查点数量吗?
Tensorflowtf.estimator.Estimator takes config
as an optional argument, which can be a tf.estimator.RunConfig配置运行时的对象settings.You可以实现如下:
# Change maximum number checkpoints to 25
run_config = tf.estimator.RunConfig()
run_config = run_config.replace(keep_checkpoint_max=25)
# Build your estimator
estimator = tf.estimator.Estimator(model_fn,
model_dir=job_dir,
config=run_config,
params=None)
config
参数在扩展 estimator.Estimator
的所有 类(DNNClassifier
、DNNLinearCombinedClassifier
、LinearClassifier
等)中可用。
作为旁注,我想补充一点,在 TensorfFlow2 中,情况稍微简单一些。要保留一定数量的检查点文件,您可以修改 model_main_tf2.py
源代码。首先,您可以添加并定义一个整数标志为
# Keep last 25 checkpoints
flags.DEFINE_integer('checkpoint_max_to_keep', 25,
'Integer defining how many checkpoint files to keep.')
然后在对 model_lib_v2.train_loop
的调用中使用此预定义值:
# Ensure training loop keeps last 25 checkpoints
model_lib_v2.train_loop(...,
checkpoint_max_to_keep=FLAGS.checkpoint_max_to_keep,
...)
上面的符号...
表示model_lib_v2.train_loop
的其他选项。
我注意到新的 Estimator API 会在训练期间自动保存检查点,并在训练中断时自动从最后一个检查点重新开始。不幸的是,它似乎只保留了最后 5 个检查点。
你知道如何控制训练过程中保留的检查点数量吗?
Tensorflowtf.estimator.Estimator takes config
as an optional argument, which can be a tf.estimator.RunConfig配置运行时的对象settings.You可以实现如下:
# Change maximum number checkpoints to 25
run_config = tf.estimator.RunConfig()
run_config = run_config.replace(keep_checkpoint_max=25)
# Build your estimator
estimator = tf.estimator.Estimator(model_fn,
model_dir=job_dir,
config=run_config,
params=None)
config
参数在扩展 estimator.Estimator
的所有 类(DNNClassifier
、DNNLinearCombinedClassifier
、LinearClassifier
等)中可用。
作为旁注,我想补充一点,在 TensorfFlow2 中,情况稍微简单一些。要保留一定数量的检查点文件,您可以修改 model_main_tf2.py
源代码。首先,您可以添加并定义一个整数标志为
# Keep last 25 checkpoints
flags.DEFINE_integer('checkpoint_max_to_keep', 25,
'Integer defining how many checkpoint files to keep.')
然后在对 model_lib_v2.train_loop
的调用中使用此预定义值:
# Ensure training loop keeps last 25 checkpoints
model_lib_v2.train_loop(...,
checkpoint_max_to_keep=FLAGS.checkpoint_max_to_keep,
...)
上面的符号...
表示model_lib_v2.train_loop
的其他选项。