TensorFlow CIFAR10: resume training from checkpoint file
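While using TensorFlow, I am trying to resume CIFAR10 training from a checkpoint file. Following some other posts, I tried tf.train.Saver().restore, but without success. Can someone tell me how to proceed?
Code snippet from TensorFlow CIFAR10: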
def train():
  # methods to build graph from the cifar10_train.py
  global_step = tf.Variable(0, trainable=False)
  images, labels = cifar10.distorted_inputs()
  logits = cifar10.inference(images)
  loss = cifar10.loss(logits, labels)
  train_op = cifar10.train(loss, global_step)
  saver = tf.train.Saver(tf.all_variables())
  summary_op = tf.merge_all_summaries()
  init = tf.initialize_all_variables()
  sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement))
  sess.run(init)
  print("FLAGS.checkpoint_dir is %s" % FLAGS.checkpoint_dir)
  if FLAGS.checkpoint_dir is None:
    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
  else:
    # restoring from the checkpoint file
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)
  # cur_step prints out well with the checkpointed variable value
  cur_step = sess.run(global_step)
  print("current step is %s" % cur_step)
  for step in xrange(cur_step, FLAGS.max_steps):
    start_time = time.time()
    # ** It gets stuck at this call **
    _, loss_value = sess.run([train_op, loss])
    # below same as original
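The problem seems to be that this line:
  tf.train.start_queue_runners(sess=sess)
...is only executed if FLAGS.checkpoint_dir is None. You still need to start the queue runners if you restore from a checkpoint; otherwise the input queues stay empty and sess.run([train_op, loss]) blocks waiting for data.
Note that I'd recommend starting the queue runners after creating the tf.train.Saver (due to a race condition in the released version of the code), so a better structure would be: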
if FLAGS.checkpoint_dir is not None:
  # restoring from the checkpoint file
  ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
  tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)
# Start the queue runners.
tf.train.start_queue_runners(sess=sess)
# ...
for step in xrange(cur_step, FLAGS.max_steps):
  start_time = time.time()
  _, loss_value = sess.run([train_op, loss])
  # ...
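To make the hang easier to picture, here is a rough analogy in plain Python (standard library only, not TensorFlow code). The names make_pipeline and train_step are hypothetical, and the timeout exists only so the demo terminates; as far as I know, a TensorFlow dequeue on an empty queue blocks indefinitely by default. The sketch shows a consumer stalling while the producer thread has not been started, then succeeding once it has.

```python
import queue
import threading


def make_pipeline(num_items):
    """Build a small producer/consumer pair; the producer thread is NOT started here."""
    q = queue.Queue(maxsize=4)

    def producer():
        # Plays the role of a queue runner: fills the input queue in the background.
        for i in range(num_items):
            q.put(i)

    thread = threading.Thread(target=producer, daemon=True)
    return q, thread


def train_step(q, timeout=0.5):
    """Stand-in for sess.run([train_op, loss]): it needs one item from the queue."""
    try:
        return q.get(timeout=timeout)
    except queue.Empty:
        # The timeout is only for this demo, so that the stall is observable.
        return None


# Case 1: the producer was never started, so the consumer finds nothing.
q, thread = make_pipeline(3)
stalled = train_step(q)

# Case 2: start the producer (the analogue of start_queue_runners), then consume.
thread.start()
values = [train_step(q) for _ in range(3)]

print("before start:", stalled)
print("after start:", values)
```

In the question's code the else branch corresponds to Case 1: the variables are restored correctly, but nothing is feeding the input queue.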