Data changing between calls in Tensorflow

I made a small change to the TensorFlow MNIST tutorial. The original code (fully_connected_feed.py, lines 194-202):

checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
saver.save(sess, checkpoint_file, global_step=global_step)
#Evaluate against the training set.
print('Training Data Eval:')
do_eval(sess, 
        eval_correct, 
        images_placeholder,
        labels_placeholder,
        data_sets.train)

I simply added one more evaluation:

checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
saver.save(sess, checkpoint_file, global_step=global_step)
print('Something strange:')
do_eval(sess, eval_correct, images_placeholder, labels_placeholder,
        data_sets.train)
#Evaluate against the training set.
print('Training Data Eval:')
do_eval(sess, 
        eval_correct, 
        images_placeholder,
        labels_placeholder,
        data_sets.train)

The results of these two evaluations are close, but not identical (the exact numbers vary from run to run):

Something strange:
  Num examples: 55000  Num correct: 49218  Precision @ 1: 0.8949
Training Data Eval:
  Num examples: 55000  Num correct: 49324  Precision @ 1: 0.8968

How is this possible? UPD: link to the TensorFlow GitHub repository: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/tutorials/mnist

The do_eval() function does in fact have a side effect, because data_sets.train is a stateful DataSet object that holds a current _index_in_epoch member, which is advanced on every call to DataSet.next_batch() (i.e. inside fill_feed_dict()).
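
The effect is easy to see in isolation. Below is a minimal sketch (ToyDataSet is a hypothetical stand-in, not the tutorial's DataSet class) showing that the cursor is shared mutable state, so a second round of "evaluation" batches does not start where the first one did:

import numpy as np

class ToyDataSet(object):
    """Hypothetical simplification: keeps only the _index_in_epoch cursor."""
    def __init__(self, labels):
        self._labels = labels
        self._index_in_epoch = 0  # shared, mutable state

    def next_batch(self, batch_size):
        start = self._index_in_epoch
        self._index_in_epoch += batch_size  # side effect of *every* call
        return self._labels[start:self._index_in_epoch]

train = ToyDataSet(np.arange(12))
print([train.next_batch(3) for _ in range(2)])  # batches [0 1 2], [3 4 5]
print([train.next_batch(3) for _ in range(2)])  # batches [6 7 8], [9 10 11] -- the cursor moved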

By itself this fact would not be enough to produce non-deterministic results, but two other details of DataSet.next_batch() introduce the non-determinism:

  1. Each time a new epoch starts, the examples are randomly shuffled.

  2. When the dataset reaches the end of an epoch, it is reset to the beginning and the last num_examples % batch_size examples are discarded. Because of the random shuffling, a different random sub-batch of examples is discarded each time, which produces the non-deterministic results (sketched below).
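
A rough sketch of those two behaviours (simplified from the tutorial's DataSet, so treat the details as an approximation rather than the exact implementation): when a requested batch would run past the end of the epoch, the leftover num_examples % batch_size examples are silently dropped, the data is reshuffled, and the cursor restarts at batch_size:

import numpy as np

class ToyDataSet(object):
    def __init__(self, images, labels):
        self._images, self._labels = images, labels
        self._num_examples = images.shape[0]
        self._index_in_epoch = 0

    def next_batch(self, batch_size):
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_examples:
            # 1. Reshuffle for the new epoch -- a different random order every time.
            perm = np.random.permutation(self._num_examples)
            self._images, self._labels = self._images[perm], self._labels[perm]
            # 2. Restart at the beginning, silently dropping the
            #    num_examples % batch_size examples left over from the old epoch.
            start = 0
            self._index_in_epoch = batch_size
        end = self._index_in_epoch
        return self._images[start:end], self._labels[start:end]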

Given the way the code is structured (the same DataSet object is shared between training and evaluation), making it deterministic is tricky. The DataSet class is sparsely documented, and this behavior is surprising, so I would consider filing a GitHub issue about it.
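
One possible workaround (a sketch only; do_eval_deterministic is a made-up name, and it assumes the DataSet exposes .images, .labels and .num_examples as in the tutorial): evaluate by slicing the underlying arrays directly instead of calling next_batch(), so the shared cursor is never advanced and every evaluation sees exactly the same examples:

def do_eval_deterministic(sess, eval_correct, images_placeholder,
                          labels_placeholder, data_set, batch_size=100):
    # Walk over the dataset in fixed slices, bypassing next_batch() entirely.
    true_count = 0
    num_examples = (data_set.num_examples // batch_size) * batch_size
    for start in range(0, num_examples, batch_size):
        feed_dict = {
            images_placeholder: data_set.images[start:start + batch_size],
            labels_placeholder: data_set.labels[start:start + batch_size],
        }
        true_count += sess.run(eval_correct, feed_dict=feed_dict)
    precision = float(true_count) / num_examples
    print('  Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
          (num_examples, true_count, precision))

Because this touches neither _index_in_epoch nor the shuffle, calling it twice in a row prints identical numbers.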