TensorFlow - 使用批量数据验证准确性
TensorFlow - validate accuracy with batch data
正如教程所说,在每个特定步骤之后,我现在需要使用 'validation' 数据集来验证模型的准确性,并最终使用 'test' 数据集来测试准确性。
示例代码:
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
validate_acc = sess.run(accuracy, feed_dict=validate_feed)
但我认为它对我的设备来说太大了,可能会发生OOM。
如何向方法'accuracy' 提供批量 validate_feed 并获得总数 'validate_acc'?
(如果我从数据集创建一个迭代器,我如何将 next_batch 送入 'accuracy' 方法?)
谢谢大家的帮助!
通常,您使用类似于以下内容的方法来衡量准确性:
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
logits 是您通常传递到 softmax - 交叉熵层的最终特征。上面计算的是给定批次的准确度,而不是整个数据集的准确度。您可以改为执行以下操作:
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_, 1))
total_correct = tf.reduce_sum(tf.cast(correct_prediction, tf.float32))
对测试集中的每个批次执行 "total_correct" 并累积它们:
correct_sum = 0
for batch in data_set:
batch_correct_count = sess.run(total_correct, feed_dict=validate_feed)
correct_sum += batch_correct_count
total_accuracy = correct_sum / data_set.size()
使用上面的公式,您可以通过批量处理数据来正确计算整体精度。当然,这是假设 for 循环在数据集中的互斥批次上运行。您应该避免禁用 iid 采样或从数据集中进行替换采样,这通常用于随机训练。
使用tf.metrics.acccuracy
。它对准确性进行流式计算,这意味着它会为您积累所有必要的信息,并在需要时 return 当前估计的准确性。
有关如何使用它的示例,请参阅 。
正如教程所说,在每个特定步骤之后,我现在需要使用 'validation' 数据集来验证模型的准确性,并最终使用 'test' 数据集来测试准确性。
示例代码:
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
validate_acc = sess.run(accuracy, feed_dict=validate_feed)
但我认为它对我的设备来说太大了,可能会发生OOM。
如何向方法'accuracy' 提供批量 validate_feed 并获得总数 'validate_acc'?
(如果我从数据集创建一个迭代器,我如何将 next_batch 送入 'accuracy' 方法?)
谢谢大家的帮助!
通常,您使用类似于以下内容的方法来衡量准确性:
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
logits 是您通常传递到 softmax - 交叉熵层的最终特征。上面计算的是给定批次的准确度,而不是整个数据集的准确度。您可以改为执行以下操作:
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_, 1))
total_correct = tf.reduce_sum(tf.cast(correct_prediction, tf.float32))
对测试集中的每个批次执行 "total_correct" 并累积它们:
correct_sum = 0
for batch in data_set:
batch_correct_count = sess.run(total_correct, feed_dict=validate_feed)
correct_sum += batch_correct_count
total_accuracy = correct_sum / data_set.size()
使用上面的公式,您可以通过批量处理数据来正确计算整体精度。当然,这是假设 for 循环在数据集中的互斥批次上运行。您应该避免禁用 iid 采样或从数据集中进行替换采样,这通常用于随机训练。
使用tf.metrics.acccuracy
。它对准确性进行流式计算,这意味着它会为您积累所有必要的信息,并在需要时 return 当前估计的准确性。
有关如何使用它的示例,请参阅