结合使用 Estimators API 和 tf.data.Dataset 如何加快批处理准备

How to speed up batch preparation when using Estimators API combined with tf.data.Dataset

我想加快使用 Estimator API 和 input_fn 使用 tf.data.Dataset 编写的训练程序。

我的实现需要 2 秒来准备一批数据,然后在 GPU 上运行训练 1 秒,然后重新开始准备一批数据。这真的很低效。

我正在寻找一种方法来异步准备批次并将它们上传到 GPU 以加快训练速度。或者另一种方法是在 input_fn 的调用之间缓存数据集(dataset.cache() 似乎不是一个好的选择,因为必须在每次 input_fn 调用时重新创建数据集)。

这是我的代码的简化版本:

def input_fn(filenames, labels, epochs):
  dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads)
  if shuffle:
     dataset = dataset.shuffle(buffer_size=len(labels))
  dataset = dataset.map(_post_process,  num_parallel_calls=num_map_threads)
  dataset = dataset.map(lambda wav, label: ({'wav': wav}, label))
  dataset = dataset.batch(128)
  dataset = dataset.repeat(epochs) # to iterate over the training set forever
  iterator = dataset.dataset.make_one_shot_iterator()
  features, labels = iterator.get_next()
  return features, labels

train_input_fn = lambda : input_fn(train_files, train_labels, None)
eval_input_fn = lambda : input_fn(eval_files, eval_labels, 1)

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=45000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn) 
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

我注意到 Estimator API 正在积极开发中,在 tensorflow 的 master 分支中 input_fn 已经可以 return 数据集,所以也许我也在问早,此功能尚未准备就绪。但如果是这样,请提供可以跟踪此实施的票证。

使用tf.data.Dataset.cache()确实不是一个好的选择,因为它会将整个数据集缓存到内存中,这需要时间并且可能会溢出你的内存。

方法是在管道的末尾使用 tf.data.Dataset.prefetch(),这将始终确保数据管道包含 buffer_size 个元素。通常最后有 buffer_size = 1 就足够了:

dataset = ...
dataset = dataset.batch(128)
dataset = dataset.prefetch(1)  # prefetch one batch

正如@mrry 在 中所解释的那样,您还可以尝试稍微增加预取批次的数量。

Typically it is most useful to add a small prefetch buffer (with perhaps just a single element) at the very end of the pipeline, but more complex pipelines can benefit from additional prefetching, especially when the time to produce a single element can vary.


如果与 GPU 计算相比,您的输入管道仍然很慢,则需要使用 tf.data.Dataset.map()num_parallel_calls 参数增加并行工作的线程数。

要添加到 Olivier 的回答中的几点,主要来自 this post:

    shuffle 之前的
  • repeat 稍微快一些,处于模糊的纪元边界的缺点。这在极少数情况下可能很重要,但我对此表示怀疑。
  • shufflemapping 之前 - 这减少了随机播放缓冲区大小的内存占用,因为它只需要缓冲文件名而不是文件内容。
  • 对我来说,将第三个地图变换应用于 get_next() 的输出而不是数据集更有意义 - 不确定这是否会影响速度。您还可以考虑将其他两个地图调用放在同一个地图调用中以减少调度问题。
  • batching 之前用 repeat 进行实验。可能不会有什么不同,但可能很小。如果您在 shuffle 之前 repeat 如上所述,您将不得不这样做。
  • 如 Olivier 所述,使用 prefetch

修改后的代码:

def input_fn(filenames, labels, epochs):
  dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  dataset = dataset.repeat(epochs)
  if shuffle:
    dataset = dataset.shuffle(buffer_size=len(labels))

  def combined_map_fn(*args):
    return _post_process(_read_wav(*args))

  dataset = dataset.map(combined_map_fn, num_parallel_calls=num_map_threads)
  dataset = dataset.batch(128)
  dataset = dataset.prefetch(1)

  iterator = dataset.dataset.make_one_shot_iterator()
  wavs, labels = iterator.get_next()
  features = {'wav': wavs}
  return features, labels