TensorFlow - 使用批量数据验证准确性

TensorFlow - validate accuracy with batch data

正如教程所说,在每个特定步骤之后,我现在需要使用 'validation' 数据集来验证模型的准确性,并最终使用 'test' 数据集来测试准确性。

示例代码:

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

validate_acc = sess.run(accuracy, feed_dict=validate_feed)

但我认为它对我的设备来说太大了,可能会发生OOM。

如何向方法'accuracy' 提供批量 validate_feed 并获得总数 'validate_acc'?

(如果我从数据集创建一个迭代器,我如何将 next_batch 送入 'accuracy' 方法?)

谢谢大家的帮助!

通常,您使用类似于以下内容的方法来衡量准确性:

correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

logits 是您通常传递到 softmax - 交叉熵层的最终特征。上面计算的是给定批次的准确度,而不是整个数据集的准确度。您可以改为执行以下操作:

correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_, 1))
total_correct = tf.reduce_sum(tf.cast(correct_prediction, tf.float32))

对测试集中的每个批次执行 "total_correct" 并累积它们:

  correct_sum = 0
  for batch in data_set:
       batch_correct_count = sess.run(total_correct, feed_dict=validate_feed)
       correct_sum += batch_correct_count

  total_accuracy = correct_sum / data_set.size()

使用上面的公式,您可以通过批量处理数据来正确计算整体精度。当然,这是假设 for 循环在数据集中的互斥批次上运行。您应该避免禁用 iid 采样或从数据集中进行替换采样,这通常用于随机训练。

使用tf.metrics.acccuracy。它对准确性进行流式计算,这意味着它会为您积累所有必要的信息,并在需要时 return 当前估计的准确性。

有关如何使用它的示例,请参阅