TensorFlow:在不洗牌的情况下读取队列中的图像

TensorFlow: Reading images in queue without shuffling

我有一个包含 614 张图像的训练集,这些图像已经被打乱。我想按顺序读取图像,每批 5 张。因为我的标签是按相同顺序排列的,所以在读取到批次中时,图像的任何混洗都会导致标签不正确。

这些是我读取图像并将其添加到批次的函数:

# To add files from queue to a batch:
def add_to_batch(image):

    print('Adding to batch')
    image_batch = tf.train.batch([image],batch_size=5,num_threads=1,capacity=614)

    # Add to summary
    tf.image_summary('images',image_batch,max_images=30)

    return image_batch

# To read files in queue and process:
def get_batch():

    # Create filename queue of images to read
    filenames = [('/media/jessica/Jessica/TensorFlow/StreetView/training/original/train_%d.png' % i) for i in range(1,614)]
    filename_queue =   tf.train.string_input_producer(filenames,shuffle=False,capacity=614)
    reader = tf.WholeFileReader()
    key, value = reader.read(filename_queue)

    # Read and process image
    # Image is 500 x 275:
    my_image = tf.image.decode_png(value)
    my_image_float = tf.cast(my_image,tf.float32)
    my_image_float = tf.reshape(my_image_float,[275,500,4])

    return add_to_batch(my_image_float)

这是我执行预测的函数:

def inference(x):

    < Perform convolution, pooling etc.>

    return y_conv

这是我计算损失和执行优化的函数:

def train_step(y_label,y_conv):

    """ Calculate loss """
    # Cross-entropy
    loss = -tf.reduce_sum(y_label*tf.log(y_conv + 1e-9))

    # Add to summary
    tf.scalar_summary('loss',loss)

    """ Optimisation """
    opt = tf.train.AdamOptimizer().minimize(loss)

    return loss

这是我的主要功能:

def main ():

    # Training
    images = get_batch()
    y_conv = inference(images)
    loss = train_step(y_label,y_conv)

    # To write and merge summaries
    writer = tf.train.SummaryWriter('/media/jessica/Jessica/TensorFlow/StreetView/SummaryLogs/log_5', graph_def=sess.graph_def)
    merged = tf.merge_all_summaries()

    """ Run session """
    sess.run(tf.initialize_all_variables())
    tf.train.start_queue_runners(sess=sess)

    print "Running..."
    for step in range(5):

        # y_1 = <get the correct labels here>

        # Train
        loss_value = sess.run(train_step,feed_dict={y_label:y_1})
        print "Step %d, Loss %g"%(step,loss_value)

        # Save summary
        summary_str = sess.run(merged,feed_dict={y_label:y_1})
        writer.add_summary(summary_str,step)

    print('Finished')

if __name__ == '__main__':
  main()

当我检查我的 image_summary 时,图像似乎没有按顺序排列。或者更确切地说,正在发生的事情是:

图片 1-5:已丢弃,图片 6-10:已读取,图片 11-15:已丢弃,图片 16-20:已读取等

看来我得到了两次批次,扔掉第一个并使用第二个?我尝试了一些补救措施,但似乎无济于事。我觉得我理解调用 images = get_batch()sess.run().

的根本错误

您的 batch 操作是 FIFOQueue,因此每次您使用它的输出时,它都会推进状态。

您的第一个 session.run 调用在 train_step 的计算中使用图像 1-5,您的第二个 session.run 请求计算 image_summary 并拉取图像 5 -6 并在可视化中使用它们。

如果您想在不影响输入状态的情况下可视化事物,将队列值缓存在变量中并使用变量定义摘要作为输入而不是依赖于实时队列会有所帮助。

(image_batch_live,) = tf.train.batch([image],batch_size=5,num_threads=1,capacity=614)

image_batch = tf.Variable(
  tf.zeros((batch_size, image_size, image_size, color_channels)),
  trainable=False,
  name="input_values_cached")

advance_batch = tf.assign(image_batch, image_batch_live)

所以现在您的 image_batch 是一个静态值,您可以将其用于计算损失和可视化。在步骤之间,您可以调用 sess.run(advance_batch) 来推进队列。

这种方法的小问题 -- 默认保护程序会将您的 image_batch 变量保存到检查点。如果您曾经更改过批量大小,那么您的检查点恢复将因维度不匹配而失败。要解决此问题,您需要指定要手动恢复的变量列表,并为其余部分指定 运行 个初始值设定项。