TFRecords 和数据集 API 甚至是小批量

TFRecords and Dataset API for even minibatches

我有大约 310 万条记录,它们被分成两个 TFRecords 文件。一个包含正 类 (~217K),另一个包含负 类 (~2.9MM)。我正在尝试使用数据集 API 以每批次有 50/50 拆分的方式交错记录。为了用完所有数据,我想重复正面示例,以便使用所有负面示例。

现在最终发生的是,它开始时是偶数,但是当正记录 运行 出来时,只有负记录出现在批次中。

如果文件名是 train_pos.tfrecords,我相信这可以在我下面的代码中通过添加 .repeat() 来解决,但是,我不知道如何修改 _get_files() 函数这样做。我认为这可能是我缺少的简单答案?

files = tf.data.Dataset.list_files("train_*.tfrecords")       
def _get_files(x):
    return tf.data.TFRecordDataset(x).shuffle(buffer_size=10000)

dataset = files.apply(tf.contrib.data.parallel_interleave(
    lambda x: _get_files(x), cycle_length=2))\
    .batch(self.batch_size)\
    .map(_parse_line, num_parallel_calls=6)\
    .repeat(1)\
    .prefetch(2)

您可以通过使用相关的 TF 记录调用两次 tf.data.Dataset 来创建两个数据集:

files1 = tf.data.Dataset.list_files(...)
files2 = tf.data.Dataset.list_files(...)

并用repeat(-1)使两个数据集取之不尽用之不竭。 然后,您可以使用两个批处理数据集的输出并将它们连接起来以获得平衡的批处理。