你如何在 `slim.learning.train()` 期间 运行 验证循环

How do you run a validation loop during `slim.learning.train()`

我正在查看此答案以了解 运行训练期间的评估指标:

似乎覆盖 train_step_fn=train_step_fn 是合理的方法。但我想要 运行 一个验证循环,而不是评估。我的图表是这样的:

with tf.Graph().as_default():

    train_dataset = slim.dataset.Dataset(data_sources= "train_*.tfrecord")
    train_images, _, train_labels = load_batch(train_dataset, 
                batch_size=mini_batch_size,
                is_training=True)

    val_dataset = slim.dataset.Dataset(data_sources= "validation_*.tfrecord")
    val_images, _, val_labels = load_batch(val_dataset, 
                batch_size=mini_batch_size,
                is_training=False)


    with slim.arg_scope(vgg.vgg_arg_scope(weight_decay=0.0005)):
        net, end_points = vgg.vgg_16(train_images, 
                                      num_classes=10,
                                      is_training=is_training)
    predictions = tf.nn.softmax(net)
    labels = train_labels

    ...

    init_fn = slim.assign_from_checkpoint_fn(
        checkpoint_path,
        slim.get_variables_to_restore(exclude=['vgg_16/fc8']),
        ignore_missing_vars=True
        )     

    final_loss = slim.learning.train(train_op, TRAIN_LOG, 
                        train_step_fn=train_step_fn,
                        init_fn=init_fn,
                        global_step=global_step,
                        number_of_steps=steps,
                        save_summaries_secs=60,
                        save_interval_secs=600,
                        session_config=sess_config,
                      )

我想添加类似这样的内容来针对网络的当前权重使用小批量进行验证循环

    def validate_on_checkpoint(sess, *args, **kwargs ):
        loss,mean,stddev = sess.run([val_loss, val_rms_mean, val_rms_stddev], 
                        feed_dict={images: val_images, 
                                   labels: val_labels, 
                                   is_training: is_training })
        validation_writer = tf.train.SummaryWriter(LOG_DIR + '/validation')                                              
        validation_writer.add_summary(loss, global_step)
        validation_writer.add_summary(mean, global_step)
        validation_writer.add_summary(stddev, global_step)


    def train_step_fn(sess, *args, **kwargs):
        total_loss, should_stop = train_step(sess, *args, **kwargs)

        if train_step_fn.step % FLAGS.validation_every_n_step == 0:
            validate_on_checkpoint(sess, *args, **kwargs )

        train_step_fn.step += 1
        return [total_loss, should_stop]   

但是我得到一个错误=Graph is finalized and cannot be modified.

从概念上讲,我不确定应该如何添加它。 training 循环需要网络的梯度、丢失和权重更新,但 validation 循环跳过所有这些。如果我尝试修改图形,我会不断收到 Graph is finalized and cannot be modified. 的变化;如果我使用 if is_training: else: 方法,我会收到 XXX is not defined 的变化

我从其他几个 Whosebug 答案中找到了一种方法来完成这项工作。以下是基础知识:

1) 获取 trainvalidation 数据集的输入和标签

x_train, y_train = produce_batch(320)
x_validation, y_validation = produce_batch(320)

2) 使用 reuse=Truetrainvalidation 循环之间重用模型权重。这是一种方法:

  with tf.variable_scope("model") as scope:
    # Make the model, reuse weights for validation batches
    predictions, nodes = regression_model(inputs, is_training=True)
    scope.reuse_variables()
    val_predictions, _ = regression_model(val_inputs, is_training=False)

3) 定义你的损失,将你的 validation 损失放在不同的集合中,这样它就不会添加到 tf.losses.get_losses()

中的 train 损失中
  loss = tf.losses.mean_squared_error(labels=targets, predictions=predictions)
  total_loss = tf.losses.get_total_loss()

  val_loss = tf.losses.mean_squared_error(labels=val_targets, predictions=val_predictions,
                                          loss_collection="validation"
                                         )

4) 根据需要定义一个train_step_fn()来触发验证循环

VALIDATION_INTERVAL = 1000 . # validate every 1000 steps
# slim.learning.train(train_step_fn=train_step_fn)
def train_step_fn(sess, train_op, global_step, train_step_kwargs):
  """
  slim.learning.train_step():
    train_step_kwargs = {summary_writer:, should_log:, should_stop:}
  """
  train_step_fn.step += 1  # or use global_step.eval(session=sess)

  # calc training losses
  total_loss, should_stop = slim.learning.train_step(sess, train_op, global_step, train_step_kwargs)


  # validate on interval
  if train_step_fn.step % VALIDATION_INTERVAL == 0:
    validiate_loss, validation_delta = sess.run([val_loss, summary_validation_delta])
    print(">> global step {}:    train={}   validation={}  delta={}".format(train_step_fn.step, 
                        total_loss, validiate_loss, validiate_loss-total_loss))


  return [total_loss, should_stop]
train_step_fn.step = 0

5) 将 train_step_fn() 添加到你的训练循环中

  # Run the training inside a session.
  final_loss = slim.learning.train(
      train_op,
      train_step_fn=train_step_fn,
      ...
      )

在此查看完整结果 Colaboratory notebook