Tensorflow 教程:输入管道中的重复洗牌

Tensorflow Tutorial: Duplicated Shuffling in the Input Pipeline

Tensorflow reading data tutorial中给出了一个示例输入管道。在该管道中,数据在 string_input_producershuffle batch generator 中被洗牌两次。这是代码:

def input_pipeline(filenames, batch_size, num_epochs=None):
  # Fist shuffle in the input pipeline
  filename_queue = tf.train.string_input_producer(
      filenames, num_epochs=num_epochs, shuffle=True)

  example, label = read_my_file_format(filename_queue)
  min_after_dequeue = 10000
  capacity = min_after_dequeue + 3 * batch_size
  # Second shuffle as part of the batching. 
  # Requiring min_after_dequeue preloaded images
  example_batch, label_batch = tf.train.shuffle_batch(
      [example, label], batch_size=batch_size, capacity=capacity,
      min_after_dequeue=min_after_dequeue)

  return example_batch, label_batch

第二次洗牌有什么用吗? shuffle 批处理生成器的缺点是 min_after_dequeue 个示例总是预先加载到内存中以允许有用的 shuffle。我确实有内存消耗非常大的图像数据。这就是为什么我考虑改用 normal batch generator 的原因。将数据洗牌两次有什么好处吗?

编辑:附加问题,为什么 string_input_producer 仅使用默认容量 32 进行初始化?将 batch_size 的倍数作为容量不是很有利吗?

是 - 这是一种常见模式,并且以最一般的方式显示。 string_input_producer 打乱了数据文件的读取顺序。为了提高效率,每个数据文件通常包含许多示例。 (读取一百万个小文件很慢;最好读取 1000 个大文件,每个文件 1000 个示例。)

因此,文件中的示例被读入一个洗牌队列,在那里它们以更细的粒度进行洗牌,因此来自同一文件的示例并不总是以相同的顺序训练,并得到混合跨输入文件。

有关详细信息,请参阅 Getting good mixing with many input datafiles in tensorflow

如果您的每个文件仅包含一个输入示例,则您无需多次随机播放,只需 string_input_producer 即可摆脱困境,但请注意,您仍然可能会受益于拥有这样的队列阅读后保留几张图片,以便您可以重叠网络的输入和训练。 batchshuffle_batchqueue_runner 将 运行 在单独的线程中,确保 I/O 在后台发生并且图像始终可用训练。当然,创建用于训练的小批量通常对速度有好处。

两种洗牌服务于不同的目的并且洗牌不同的东西:

  • tf.train.string_input_producer 随机播放:布尔值。如果为真,字符串将在每个时期内随机洗牌。。因此,如果您有几个文件 ['file1', 'file2', ..., 'filen'],这会从该列表中随机选择一个文件。如果为false,文件一个接一个。
  • tf.train.shuffle_batch 通过随机打乱张量来创建批次。 所以它从你的队列 read_my_file_format 中取出 batch_size 个张量并打乱它们。

因为两次洗牌做不同的事情,所以将数据洗牌两次是有好处的。即使您使用一批 256 张图像,并且每张图像都是 256x256 像素,您也将消耗不到 100 Mb 的内存。如果在某个时候您会发现内存问题,您可以尝试减小批量大小。

关于默认容量 - 它是 。让它大于 batch_size 并确保它在训练期间永远不会为空是有意义的。

为了回答附加问题,string_input_producer returns 一个包含 文件名称的队列 包含样本,而不是样本本身。然后 shuffle_batch 使用此文件名来加载数据。因此,加载的样本数与 shuffle_batch 函数的 capacity 参数有关,而不是 string_input_producer