如何使用 tfslim 记录验证损失和准确性

how to log validation loss and accuracy using tfslim

有什么方法可以在使用 tf-slim 时将验证损失和准确性记录到 tensorboard 中?当我使用keras时,下面的代码可以为我做这件事:

model.fit_generator(generator=train_gen(), validation_data=valid_gen(),...)

然后模型会在每个epoch之后评估validation loss和accuracy,非常方便。但是如何使用 tf-slim 来实现呢?以下步骤使用原始的tensorflow,这不是我想要的:

with tf.Session() as sess:
    for step in range(100000):
        sess.run(train_op, feed_dict={X: X_train, y: y_train})
        if n % batch_size * batches_per_epoch == 0:
            print(sess.run(train_op, feed_dict={X: X_train, y: y_train}))

现在,使用 tf-slim 训练模型的步骤是:

tf.contrib.slim.learning.train(
    train_op=train_op,
    logdir="logs",
    number_of_steps=10000,
    log_every_n_steps = 10,
    save_summaries_secs=1
)

那么如何使用上述 slim 训练程序评估每个 epoch 后的验证损失和准确性?

提前致谢!

此事仍在 TF Slim 回购 (issue #5987) 上讨论。 该框架允许您在训练之后/并行地轻松创建一个 运行 的评估脚本(下面的解决方案 1),但有些人正在推动能够实施 "classic cycle of batch training + validation"(解决方案 2)。


1。在另一个脚本中使用 slim.evaluation

TF Slim 有评估方法,例如slim.evaluation.evaluation_loop() 您可以在另一个脚本中使用(可以 运行 与您的训练并行)来定期加载模型的最新检查点并执行评估。 TF Slim 页面包含一个很好的示例,这样的脚本可能看起来如何:example.

2。提供自定义 train_step_fnslim.learning.train()

讨论发起者提出的一个不完善的解决方案使用了您可以提供给 slim.learning.train() 的自定义训练步骤函数:

"""
Snippet from code by Kevin Malakoff @kmalakoff
https://github.com/tensorflow/tensorflow/issues/5987#issue-192626454
"""
# ...
accuracy_validation = slim.metrics.accuracy(
    tf.argmax(predictions_validation, 1), 
    tf.argmax(labels_validation, 1)) # ... or whatever metrics needed

def train_step_fn(session, *args, **kwargs):
  total_loss, should_stop = train_step(session, *args, **kwargs)

  if train_step_fn.step % FLAGS.validation_check == 0:
    accuracy = session.run(train_step_fn.accuracy_validation)
    print('Step %s - Loss: %.2f Accuracy: %.2f%%' % (str(train_step_fn.step).rjust(6, '0'), total_loss, accuracy * 100))

  # ...

  train_step_fn.step += 1
  return [total_loss, should_stop]

train_step_fn.step = 0
train_step_fn.accuracy_validation = accuracy_validation

slim.learning.train(
  train_op,
  FLAGS.logs_dir,
  train_step_fn=train_step_fn,
  graph=graph,
  number_of_steps=FLAGS.max_steps
)