如何控制张量流估算器保留的检查点数量?

How to control amount of checkpoint kept by tensorflow estimator?

我注意到新的 Estimator API 会在训练期间自动保存检查点,并在训练中断时自动从最后一个检查点重新开始。不幸的是,它似乎只保留了最后 5 个检查点。

你知道如何控制训练过程中保留的检查点数量吗?

Tensorflowtf.estimator.Estimator takes config as an optional argument, which can be a tf.estimator.RunConfig配置运行时的对象settings.You可以实现如下:

# Change maximum number checkpoints to 25
run_config = tf.estimator.RunConfig()
run_config = run_config.replace(keep_checkpoint_max=25)

# Build your estimator
estimator = tf.estimator.Estimator(model_fn,
                                   model_dir=job_dir,
                                   config=run_config,
                                   params=None)

config 参数在扩展 estimator.Estimator 的所有 类(DNNClassifierDNNLinearCombinedClassifierLinearClassifier 等)中可用。

作为旁注,我想补充一点,在 TensorfFlow2 中,情况稍微简单一些。要保留一定数量的检查点文件,您可以修改 model_main_tf2.py 源代码。首先,您可以添加并定义一个整数标志为

# Keep last 25 checkpoints
flags.DEFINE_integer('checkpoint_max_to_keep', 25,
                     'Integer defining how many checkpoint files to keep.')

然后在对 model_lib_v2.train_loop 的调用中使用此预定义值:

# Ensure training loop keeps last 25 checkpoints
model_lib_v2.train_loop(...,
                        checkpoint_max_to_keep=FLAGS.checkpoint_max_to_keep,
                        ...)

上面的符号...表示model_lib_v2.train_loop的其他选项。