How to generate a .ckpt file for training and inference in the TensorFlow Object Detection API

I am trying to do object detection with the TensorFlow Object Detection API, using the EfficientDet D3 model from the TensorFlow model zoo. I also found a folder of pre-trained checkpoints, listed under section 2 (pre-trained checkpoints).

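I need to specify the path to the checkpoint in the .config file for training. But I cannot find a .ckpt file, neither in the model nor in the pre-trained checkpoint folder downloaded above.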

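I found a similar issue reported for Magenta, but the fix there did not work for me. If anyone knows how to generate a model.ckpt file from model.ckpt-data-00000-of-000001, model.ckpt.index and model.ckpt.meta in TensorFlow 2, please tell me; it might solve my problem.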

I am using TensorFlow 2 from Google Colab.

Edit 1: I downloaded the model from the TensorFlow model zoo. It has the structure below.

At path /content/models/research/object_detection/EfficientDet_D3/ # this is model dir

efficientdet_d3_coco17_tpu-32/
efficientdet_d3_coco17_tpu-32/checkpoint/
efficientdet_d3_coco17_tpu-32/checkpoint/ckpt-0.data-00000-of-00001
efficientdet_d3_coco17_tpu-32/checkpoint/checkpoint
efficientdet_d3_coco17_tpu-32/checkpoint/ckpt-0.index
efficientdet_d3_coco17_tpu-32/pipeline.config
efficientdet_d3_coco17_tpu-32/saved_model/
efficientdet_d3_coco17_tpu-32/saved_model/saved_model.pb
efficientdet_d3_coco17_tpu-32/saved_model/assets/
efficientdet_d3_coco17_tpu-32/saved_model/variables/
efficientdet_d3_coco17_tpu-32/saved_model/variables/variables.data-00000-of-00001
efficientdet_d3_coco17_tpu-32/saved_model/variables/variables.index

I also downloaded the checkpoint using the link from the EfficientDet README on GitHub. Its structure is shown below.

At path /content/models/research/object_detection/efficientdet-d3/ # this is checkpoint dir

efficientdet-d3/
efficientdet-d3/model.meta
efficientdet-d3/d3_coco_val_softnms.txt
efficientdet-d3/d3_coco_test-dev2017_softnms.txt
efficientdet-d3/model.index
efficientdet-d3/detections_test-dev2017_d3_results.zip
efficientdet-d3/checkpoint
efficientdet-d3/model.data-00000-of-00001

I specified the path to the .ckpt in pipeline.config as follows.

fine_tune_checkpoint: "/content/models/research/object_detection/efficientdet-d3/model.ckpt"

But that does not seem to be correct, because I get the following error message.

Traceback (most recent call last):
  File "model_main_tf2.py", line 113, in <module>
    tf.compat.v1.app.run()
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "model_main_tf2.py", line 110, in main
    record_summaries=FLAGS.record_summaries)
  File "/root/.local/lib/python3.6/site-packages/object_detection-0.1-py3.6.egg/object_detection/model_lib_v2.py", line 569, in train_loop
    unpad_groundtruth_tensors)
  File "/root/.local/lib/python3.6/site-packages/object_detection-0.1-py3.6.egg/object_detection/model_lib_v2.py", line 345, in load_fine_tune_checkpoint
    if not is_object_based_checkpoint(checkpoint_path):
  File "/root/.local/lib/python3.6/site-packages/object_detection-0.1-py3.6.egg/object_detection/model_lib_v2.py", line 308, in is_object_based_checkpoint
    var_names = [var[0] for var in tf.train.list_variables(checkpoint_path)]
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/checkpoint_utils.py", line 98, in list_variables
    reader = load_checkpoint(ckpt_dir_or_file)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/checkpoint_utils.py", line 67, in load_checkpoint
    return py_checkpoint_reader.NewCheckpointReader(filename)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/py_checkpoint_reader.py", line 99, in NewCheckpointReader
    error_translator(e)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/py_checkpoint_reader.py", line 35, in error_translator
    raise errors_impl.NotFoundError(None, None, error_message)
tensorflow.python.framework.errors_impl.NotFoundError: Unsuccessful TensorSliceReader constructor: Failed to find any matching files for /content/models/research/object_detection/efficientdet-d3/model.ckpt

Usually, when you download a pre-trained model, you get 7 files.

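  • saved_model, the model in the TensorFlow SavedModel format (https://www.tensorflow.org/guide/saved_model)
  • frozen_inference_graph, a model that can only be used for inference: all weights are frozen and it cannot be trained any further
  • Three checkpoint files: the .meta file with all the metadata, the index that points to the right checkpoint, and one or more data files.
  • pipeline.config, which contains the configuration used for the previous training.
  • Finally, the checkpoint file, which contains these lines:
    model_checkpoint_path: "model.ckpt"
    all_model_checkpoint_paths: "model.ckpt"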
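
As a minimal sketch of the last bullet: the checkpoint file is a small text file, so its model_checkpoint_path entry can be read with a regular expression. The helper name read_checkpoint_prefix is hypothetical, and the sketch assumes the one-entry-per-line layout shown above. In TensorFlow itself, tf.train.latest_checkpoint(directory) serves the same purpose.

```python
import os
import re


def read_checkpoint_prefix(checkpoint_dir):
    """Return the prefix named by model_checkpoint_path, or None.

    Hypothetical helper: assumes the text layout shown above,
    with one `key: "value"` entry per line.
    """
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.isfile(state_file):
        return None
    with open(state_file, encoding="utf-8") as f:
        for line in f:
            # re.match anchors at the line start, so the
            # all_model_checkpoint_paths entry is skipped.
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"\s*$', line)
            if match:
                # Relative entries are resolved against the checkpoint directory.
                return os.path.join(checkpoint_dir, match.group(1))
    return None
```

The returned value is a prefix, not a file on disk, which is why no file with exactly that name shows up in the directory.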

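In short, the .ckpt you are looking for does not really exist as a single file; it is just the collection of the 4 checkpoint files, referred to by the prefix they share. To use it, just put this in your config file: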

  fine_tune_checkpoint: ".../efficientnet/models/model.ckpt"
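
Because the value of fine_tune_checkpoint is a file prefix rather than a file, one way to find the right value is to look for the .index file and drop its extension. The sketch below assumes that naming convention and uses hypothetical helper names (find_checkpoint_prefixes, fine_tune_line); it only inspects file names and does not load anything with TensorFlow.

```python
import glob
import os


def find_checkpoint_prefixes(checkpoint_dir):
    """List checkpoint prefixes in a directory, one per `.index` file.

    Hypothetical helper: a prefix is kept only if at least one
    matching `<prefix>.data-*` shard sits next to the index file.
    """
    pattern = os.path.join(glob.escape(checkpoint_dir), "*.index")
    prefixes = []
    for index_path in sorted(glob.glob(pattern)):
        prefix = index_path[: -len(".index")]
        if glob.glob(glob.escape(prefix) + ".data-*"):
            prefixes.append(prefix)
    return prefixes


def fine_tune_line(prefix):
    """Format a prefix as a pipeline.config entry."""
    return 'fine_tune_checkpoint: "{}"'.format(prefix)
```

For the two directory listings in the question, the .index files are named ckpt-0.index and model.index, so this approach would yield prefixes ending in checkpoint/ckpt-0 and efficientdet-d3/model rather than model.ckpt.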

TensorFlow documentation on checkpoints: https://www.tensorflow.org/guide/checkpoint