Determining the Epoch Number with tf.train.string_input_producer in TensorFlow

I have some doubts about how tf.train.string_input_producer works. Suppose I feed filename_list as an input argument to string_input_producer. Then, according to the documentation https://www.tensorflow.org/programmers_guide/reading_data, this creates a FIFOQueue, where I can set the number of epochs, shuffle the file names, and so on. In my case, I have 4 file names ("db1.tfrecords", "db2.tfrecords", ...), and I use tf.train.batch to feed the network batches of images. In addition, each file_name/database contains the images of a single person: the second database is for the second person, and so on. So far I have the following code:

tfrecords_filename_seq = [(common + "P16_db.tfrecords"), (common + "P17_db.tfrecords"), (common + "P19_db.tfrecords"),
                          (common + "P21_db.tfrecords")]

filename_queue = tf.train.string_input_producer(tfrecords_filename_seq, num_epochs=num_epoch, shuffle=False, name='queue')
reader = tf.TFRecordReader()

key, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
    serialized_example,
    # Defaults are not specified since all keys are required.
    features={
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'image_raw': tf.FixedLenFeature([], tf.string),
        'annotation_raw': tf.FixedLenFeature([], tf.string)
    })

image = tf.decode_raw(features['image_raw'], tf.uint8)
height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)

image = tf.reshape(image, [height, width, 3])

annotation = tf.cast(features['annotation_raw'], tf.string)

min_after_dequeue = 100
num_threads = 4
capacity = min_after_dequeue + num_threads * batch_size
label_batch, images_batch = tf.train.batch([annotation, image],
                                                        shapes=[[], [112, 112, 3]],
                                                        batch_size=batch_size,
                                                        capacity=capacity,
                                                        num_threads=num_threads)
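The behaviour I describe below (images arriving one database at a time) follows from this pipeline having a single reader and a plain FIFO batcher. As a rough illustration, here is a plain-Python simulation of that single-reader pipeline; the record counts are made up and none of this is TensorFlow API:

```python
# Hypothetical simulation (plain Python, no TensorFlow): one reader dequeues
# filenames in order and emits each file's records sequentially, so a FIFO
# batcher (like tf.train.batch with one reader) yields batches dominated by
# a single file at a time.

from collections import deque

def simulate_single_reader(files, records_per_file, batch_size):
    """Yield batches the way a single-reader FIFO pipeline would."""
    queue = deque()
    for name in files:                       # filename queue, shuffle=False
        for i in range(records_per_file):    # reader emits records in order
            queue.append((name, i))
    batches = []
    while queue:
        batch = [queue.popleft() for _ in range(min(batch_size, len(queue)))]
        batches.append(batch)
    return batches

batches = simulate_single_reader(
    ["P16_db", "P17_db", "P19_db", "P21_db"], records_per_file=6, batch_size=3)

# Every batch comes from exactly one file: the pipeline walks the
# databases one person at a time.
print([{name for name, _ in b} for b in batches[:4]])
```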

Finally, when I tried to look at the reconstructed images at the output of my autoencoder, I first got the images from the first database, then I started seeing images from the second database, and so on.

My question: how do I know whether I am still in the same epoch? And, within a single epoch, how can I merge images from all of the file_names I have into one batch?
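For the first part of my question, one observation: if the total number of records in one pass over all the files is known, the epoch can be tracked with plain integer arithmetic on a running record counter. This is only a sketch in plain Python; records_per_epoch is an assumed quantity, not something string_input_producer reports directly:

```python
# Hypothetical sketch: derive the current epoch from a running count of
# records consumed. `records_per_epoch` is an assumed number (total records
# across all four files), not a value exposed by the TensorFlow queue.

def current_epoch(records_seen, records_per_epoch):
    """0-based epoch index after `records_seen` records have been consumed."""
    return records_seen // records_per_epoch

records_per_epoch = 4 * 250   # e.g. 4 files with 250 images each (assumed)
print(current_epoch(0, records_per_epoch))     # → 0, still in the first epoch
print(current_epoch(999, records_per_epoch))   # → 0, last record of epoch one
print(current_epoch(1000, records_per_epoch))  # → 1, second epoch begins
```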

Finally, I tried to print out the value of the epoch by evaluating a local variable in the Session, like this:

epoch_var = tf.local_variables()[0]

Then:

with tf.Session() as sess:
    print(sess.run(epoch_var))  # Here I got 9 as output; I don't know why.

Any help is much appreciated!!
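One thing I noticed while debugging: picking tf.local_variables()[0] is fragile, because the index depends on graph construction order. Looking the variable up by name is safer. The sketch below is plain Python standing in for that lookup; the variable names in it are assumed examples, and in a real graph you would inspect [v.name for v in tf.local_variables()] to see the actual names:

```python
# Hypothetical sketch: select the epoch-counter variable by name rather than
# by position. The (name, value) pairs stand in for tf.local_variables();
# the names shown are assumptions, not guaranteed TensorFlow names.

def find_epoch_var(local_vars):
    """Return the first variable whose name mentions 'epochs'."""
    matches = [v for v in local_vars if "epochs" in v[0]]
    if not matches:
        raise LookupError("no epoch counter found among local variables")
    return matches[0]

local_vars = [
    ("queue/limit_epochs/epochs:0", 3),   # assumed counter from the producer
    ("some_metric/count:0", 9),           # an unrelated local variable
]
name, value = find_epoch_var(local_vars)
print(name, value)
```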

It turns out that using tf.train.shuffle_batch_join solved my problem, since it starts shuffling together images from the different datasets. In other words, every batch now contains images from all of the datasets/file_names. Here is an example:

def read_my_file_format(filename_queue):
    reader = tf.TFRecordReader()
    key, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        # Defaults are not specified since all keys are required.
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
            'annotation_raw': tf.FixedLenFeature([], tf.string)
        })

    # This is how we create one example, that is, extract one example from the database.
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    # The height and the width are needed below to restore the image shape.
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)

    # The image is reshaped since, when stored in binary format, it is flattened. Therefore, we need the
    # height and the width to restore the original image.
    image = tf.reshape(image, [height, width, 3])

    annotation = tf.cast(features['annotation_raw'], tf.string)
    return annotation, image

def input_pipeline(filenames, batch_size, num_threads, num_epochs=None):
    filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=False,
                                                    name='queue')
    # Note that here we create num_threads readers, all reading from the same filename_queue.
    example_list = [read_my_file_format(filename_queue=filename_queue) for _ in range(num_threads)]
    min_after_dequeue = 100
    capacity = min_after_dequeue + num_threads * batch_size
    label_batch, images_batch = tf.train.shuffle_batch_join(example_list,
                                                            shapes=[[], [112, 112, 3]],
                                                            batch_size=batch_size,
                                                            capacity=capacity,
                                                            min_after_dequeue=min_after_dequeue)
    return label_batch, images_batch, example_list

label_batch, images_batch, input_ann_img = \
    input_pipeline(tfrecords_filename_seq, batch_size, num_threads, num_epochs=num_epoch)

This creates multiple readers that read from the FIFOQueue, each followed by its own decoder. Finally, after the images are decoded, they are fed into another queue, created by the call to tf.train.shuffle_batch_join, which supplies the network with batches of images.
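As a sanity check on why the batches now mix people, here is a plain-Python simulation (not TensorFlow) of what shuffle_batch_join changes: several reader streams are interleaved into a shuffling buffer, and each dequeued batch draws from that mixed buffer. The buffer sizes mirror min_after_dequeue and batch_size above, while the file names and record counts are made up:

```python
# Hypothetical simulation: multiple readers feed a shuffling buffer, so a
# dequeued batch mixes records from different files, unlike the single-reader
# FIFO pipeline.

import random

def simulate_shuffle_batch_join(files, records_per_file, batch_size,
                                min_after_dequeue, seed=0):
    rng = random.Random(seed)
    # Interleave the per-file streams round-robin, as multiple readers would.
    streams = [[(name, i) for i in range(records_per_file)] for name in files]
    interleaved = [rec for group in zip(*streams) for rec in group]
    buffer, batches = [], []
    for record in interleaved:
        buffer.append(record)
        # Dequeue a random batch once the buffer holds enough records.
        if len(buffer) >= min_after_dequeue + batch_size:
            rng.shuffle(buffer)
            batches.append([buffer.pop() for _ in range(batch_size)])
    return batches

batches = simulate_shuffle_batch_join(
    ["P16_db", "P17_db", "P19_db", "P21_db"],
    records_per_file=50, batch_size=8, min_after_dequeue=20)

# Batches now contain images from several databases at once.
print({name for name, _ in batches[0]})
```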