使用 TensorFlow 对象检测确定最大批量 API

Determining max batch size with TensorFlow Object Detection API

TF 对象检测 API 默认情况下会占用所有 GPU 内存,因此很难说我可以进一步增加多少批量大小。通常我只是继续增加它,直到出现 CUDA OOM 错误。

另一方面,PyTorch 默认情况下不会获取所有 GPU 内存,因此很容易看出我还剩下多少百分比可以使用,而无需所有的试验和错误。

是否有更好的方法来确定我缺少的 TF 对象检测 API 的批量大小? model_main.py?

allow-growth 标志之类的东西

我一直在查看源代码,但没有找到与此相关的FLAG。

但是,在 https://github.com/tensorflow/models/blob/master/research/object_detection/model_main.py 的文件 model_main.py 中 您可以找到以下主要函数定义:

def main(unused_argv):
  flags.mark_flag_as_required('model_dir')
  flags.mark_flag_as_required('pipeline_config_path')
  config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)

  train_and_eval_dict = model_lib.create_estimator_and_inputs(
      run_config=config,
...

想法是以类似的方式修改它,例如以下方式:

config_proto = tf.ConfigProto()
config_proto.gpu_options.allow_growth = True

config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir, session_config=config_proto)

因此,添加 config_proto 并更改 config 但保持所有其他条件相同。

此外,allow_growth 使程序根据需要使用尽可能多的 GPU 内存。因此,根据您的 GPU,您最终可能会吃掉所有内存。在这种情况下,您可能需要使用

config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9

定义要使用的内存部分。

希望对您有所帮助。

如果您不想修改文件,似乎应该打开一个问题,因为我没有看到任何 FLAG。除非 FLAG

flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
                    'file.')

表示与此相关的内容。但我不这么认为,因为从 model_lib.py 中看来,它与训练、评估和推断配置有关,而不是 GPU 使用配置。