How to load a model using .ckpt.data and .ckpt.index
In my code, I have been loading models from a single .ckpt file, such as inception_v4.ckpt. Now I am trying to use a pretrained PNASNet model, which ships as two separate files, .ckpt.data and .ckpt.index. Can anyone tell me how to load the model from these two files?
In the code that evaluates the model, it loads the model by passing a directory path as checkpoint_path. So I tried giving a path like that, but it did not work.
def _get_init_fn():
  """Returns a function run by the chief worker to warm-start the training.

  Note that the init_fn is only run when initializing the model during the very
  first global step.

  Returns:
    An init function run by the supervisor.
  """
  if FLAGS.checkpoint_path is None:
    return None

  # Warn the user if a checkpoint exists in the train_dir. Then we'll be
  # ignoring the checkpoint anyway.
  if tf.train.latest_checkpoint(FLAGS.train_dir):
    tf.logging.info(
        'Ignoring --checkpoint_path because a checkpoint already exists in %s'
        % FLAGS.train_dir)
    return None

  exclusions = []
  if FLAGS.checkpoint_exclude_scopes:
    exclusions = [scope.strip()
                  for scope in FLAGS.checkpoint_exclude_scopes.split(',')]

  # TODO(sguada) variables.filter_variables()
  variables_to_restore = []
  for var in slim.get_model_variables():
    excluded = False
    for exclusion in exclusions:
      if var.op.name.startswith(exclusion):
        excluded = True
        break
    if not excluded:
      variables_to_restore.append(var)

  if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
    checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
  else:
    checkpoint_path = FLAGS.checkpoint_path

  tf.logging.info('Fine-tuning from %s' % checkpoint_path)

  return slim.assign_from_checkpoint_fn(
      checkpoint_path,
      variables_to_restore,
      ignore_missing_vars=FLAGS.ignore_missing_vars)
The above is the code that loads from a .ckpt file.
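For what it's worth, the reason the directory path "does not work" is likely that tf.train.latest_checkpoint looks for a text file literally named `checkpoint` in that directory and returns the prefix recorded inside it; a directory containing only the .data and .index shards yields None. Below is a rough, pure-Python illustration of that lookup (the helper name `resolve_latest_checkpoint` is mine, not a TensorFlow API):

```python
import os
import re
import tempfile

def resolve_latest_checkpoint(ckpt_dir):
    """Roughly mimic tf.train.latest_checkpoint: read the 'checkpoint'
    state file in the directory and return the prefix it names, or None
    if no such state file exists."""
    state_file = os.path.join(ckpt_dir, "checkpoint")
    if not os.path.exists(state_file):
        return None
    with open(state_file) as f:
        m = re.search(r'model_checkpoint_path:\s*"([^"]+)"', f.read())
    if m is None:
        return None
    path = m.group(1)
    # A relative prefix is resolved against the checkpoint directory.
    return path if os.path.isabs(path) else os.path.join(ckpt_dir, path)

# A directory with only the .data/.index shards resolves to nothing...
d = tempfile.mkdtemp()
open(os.path.join(d, "pnasnet.ckpt.index"), "w").close()
open(os.path.join(d, "pnasnet.ckpt.data-00000-of-00001"), "w").close()
print(resolve_latest_checkpoint(d))  # None

# ...until a 'checkpoint' state file names the prefix.
with open(os.path.join(d, "checkpoint"), "w") as f:
    f.write('model_checkpoint_path: "pnasnet.ckpt"\n')
print(resolve_latest_checkpoint(d))
```

So passing a directory only works when that `checkpoint` state file is present; otherwise you must pass the prefix yourself.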
Just use the model name, model.ckpt, as the checkpoint path. You don't need to care about the .data and .index parts.
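In other words, the .data-NNNNN-of-NNNNN and .index files are shards of one logical checkpoint, and the restore call takes their common prefix. A small sketch of stripping those suffixes to recover the prefix (the `checkpoint_prefix` helper is hypothetical, just to illustrate the naming convention):

```python
import re

def checkpoint_prefix(filename):
    """Strip a trailing .index or .data-XXXXX-of-XXXXX suffix from a
    checkpoint shard name, leaving the prefix passed to the restore call."""
    return re.sub(r'\.(index|data-\d{5}-of-\d{5})$', '', filename)

print(checkpoint_prefix("pnasnet.ckpt.data-00000-of-00001"))  # pnasnet.ckpt
print(checkpoint_prefix("pnasnet.ckpt.index"))                # pnasnet.ckpt

# With that prefix, the usual TF1 restore applies, e.g.:
#   saver = tf.train.Saver()
#   saver.restore(sess, "/path/to/pnasnet.ckpt")
```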