如何在 Tensorflow 对象检测 API 中存储最佳模型检查点,而不仅仅是最新的 5 个?

How to store best models checkpoints, not only newest 5, in Tensorflow Object Detection API?

我正在 WIDER FACE 数据集上训练 MobileNet,我遇到了无法解决的问题。 TF 对象检测 API 仅在 train 目录中存储最后 5 个检查点,但我想做的是保存与 mAP 指标相关的最佳模型(或者至少在 [=11] 中保留更多模型=] 删除前的目录)。 例如,今天我在第二天晚上的训练后查看了 Tensorboard,发现夜间模型过度拟合,我无法恢复最佳检查点,因为它已经被删除了。

编辑:我只使用 Tensorflow Object Detection API,默认情况下它会在我指向的训练目录中保存最后 5 个检查点。我寻找一些配置参数或任何可以改变此行为的参数。

有没有人在 code/config 到 set/workaround 的参数中有一些修复?似乎我遗漏了一些东西,很明显实际上重要的是最好的模型,而不是最新的模型(可能会过拟合)。

谢谢!

您可以修改(在您的分支中进行硬编码或打开拉取请求并将选项添加到原型)传递给 tf.train.Saver 的参数:

https://github.com/tensorflow/models/blob/master/research/object_detection/legacy/trainer.py#L376-L377

您可能想要设置:

  • max_to_keep:要保留的最近检查点的最大数量。默认为 5。
  • keep_checkpoint_every_n_hours:保持检查点的频率。默认为 10,000 小时。

您可能对此 Tf github thread that tackles the newest/best checkpoint issue. A user developed his own wrapper, chekmate 感兴趣,大约 tf.Saver 以跟踪最佳检查点。

您可以更改配置。

在run_config.py

class RunConfig(object):
  """This class specifies the configurations for an `Estimator` run."""

  def __init__(self,
           model_dir=None,
           tf_random_seed=None,
           save_summary_steps=100,
           save_checkpoints_steps=_USE_DEFAULT,
           save_checkpoints_secs=_USE_DEFAULT,
           session_config=None,
           keep_checkpoint_max=10,
           keep_checkpoint_every_n_hours=10000,
           log_step_count_steps=100,
           train_distribute=None,
           device_fn=None,
           protocol=None,
           eval_distribute=None,
           experimental_distribute=None):

可以跟进this PR。这里你最好的检查点保存在你的检查点目录中,sub-directory 命名为 best.

你只需要整合 best_saver() 和 (_run_checkpoint_once() 中的方法调用) 里面../object_detection/eval_util.py

此外,它还会为 all_evaluation_metrices 创建一个 json。

为了保存更多的检查点,您可以编写一个简单的python脚本,将检查点及时存储到特定的。

import os
import shutil
import time

while True:
    
    training_file = '/home/vignesh/training' # path of your train directory
    archive_file = 'home/vignesh/training/archive' #path of the directory where you want to save your checkpoints
    files_to_save = []

    for files in os.listdir(training_file):
        
        if files.rsplit('.')[0]=='model':
            
            files_to_save.append(files)

    for files in files_to_save:
        if files in os.listdir(archive_file):
            pass
        else:
            shutil.copy2(training_file+'/'+files,archive_file)
    time.sleep(600) # This will make the script run for every 600 seconds, modify it for your need