分布式模式下的 TensorFlow 1.4 新数据集 API

TensorFlow 1.4 new Dataset API in Distributed Mode

在 TensorFlow 1.4 的新数据集 API 之前,我使用以下代码在不同工作人员之间创建共享文件名队列:

# queue with the file names that can be shared amongst workers during training
filename_queue = tf.FIFOQueue(100, tf.string, shared_name=shared_name)
enque_op = filename_queue.enqueue_many([tf.train.limit_epochs(file_names, num_epochs)])
close_op = filename_queue.close(cancel_pending_enqueues=True)

# create queue runner and add it to queue runners
qr = tf.train.QueueRunner(filename_queue, [enque_op], close_op,
                          queue_closed_exception_types=(tf.errors.OutOfRangeError, tf.errors.CancelledError))
tf.train.add_queue_runner(qr)

# read example from file
reader = tf.TFRecordReader()
_, example = reader.read(filename_queue)

# parse example
image, ground_truth, example_name = parse_example(example)

此代码使用了队列和队列运行器,它非常丑陋且令人困惑。但它允许选项 shared_name= 在工作人员之间创建共享队列,这样他们就不会处理相同的示例。

TensorFlow 1.4 新版本发布后 input pipelines 变得更加易于使用。所以我想更新我的程序以使用这个新功能。 但是,我在新文档中的任何地方都找不到如何在工作人员之间共享数据集。

这是自动完成的还是不是一项功能?

您可以使用 tf.data.Dataset.shard(参见 documentation)来实现此目的。该文档说明了如何 "shard" 单个文件的元素或(如您的示例)"shard" 文件名。