OutOfRangeError: tensorflow iterator not reinitializing between runs

OutOfRangeError: tensorflow iterator not reinitializing between runs

我正在使用以下设置通过 tensorflow 微调 Inception 模型,并且正在喂入批次 tf.DatasetAPI。但是,每次我尝试训练此模型(在成功检索任何批次之前)时,我都会收到一个 OutOfRangeError,声称迭代器已耗尽:

Caught OutOfRangeError. Stopping Training. End of sequence
     [[node IteratorGetNext (defined at <ipython-input-8-c768436e70d8>:13)  = IteratorGetNext[output_shapes=[[?,224,224,3], [?,1]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]
with tf.Graph().as_default():

作为 get_batch 的结果,我创建了一个函数来输入硬编码的批次,并且它运行和收敛没有任何问题,这让我相信图形和会话代码工作正常。我还测试了 get_batch 函数在会话中迭代,这不会导致错误。我期望的行为是重新开始训练(尤其是重置笔记本等)会在数据集上产生一个新的迭代器。

训练模型的代码:

with tf.Graph().as_default():

    tf.logging.set_verbosity(tf.logging.INFO)
    images, labels = get_batch(filenames=tf_train_record_path+train_file)
    # Create the model, use the default arg scope to configure the batch norm parameters.
    with slim.arg_scope(inception.inception_v1_arg_scope()):
        logits, ax = inception.inception_v1(images, num_classes=1, is_training=True)

    # Specify the loss function:
    tf.losses.mean_squared_error(labels,logits)
    total_loss = tf.losses.get_total_loss()
    tf.summary.scalar('losses/Total_Loss', total_loss)


     # Specify the optimizer and create the train op:
    optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
    train_op = slim.learning.create_train_op(total_loss, optimizer)

    # Run the training:
    final_loss = slim.learning.train(
        train_op,
        logdir=train_dir,
        init_fn=get_init_fn(),
        number_of_steps=1)

使用数据集获取批次的代码

def get_batch(filenames):
    dataset = tf.data.TFRecordDataset(filenames=filenames)

    dataset = dataset.map(parse)
    dataset = dataset.batch(2)

    iterator = dataset.make_one_shot_iterator()
    data_X, data_y = iterator.get_next()

    return data_X, data_y 

之前提出的问题类似于我遇到的问题,但是,我没有使用 batch_join 调用。我不知道这是否是 slim.learning.train、从检查点恢复或范围的问题。如有任何帮助,我们将不胜感激!

您的输入管道看起来没问题。问题可能出在损坏的 TFRecords 文件上。您可以使用随机数据尝试您的代码,或者使用您的图像作为 tf.data.Dataset.from_tensor_slices() 的 numpy 数组。 您的解析功能也可能会导致问题。尝试用 sess.run 打印你的 image/label。

我建议使用 Estimator API 作为 train_op。它更方便,很快就会弃用 slim。