TensorFlow:数据集的多线程拆包

TensorFlow: multithreaded unbatching of datasets

我正在使用 TensorFlow 2.0 测试版。我有一个 TensorFlow Dataset,其中每个元素都是一批特征列:一个张量元组,其中每个都有 batch_size 记录的特定特征的值。我需要将这些记录展平以序列化为 TFRecords,我想使用 TensorFlow Dataset 函数来完成。扁平化记录不需要以确定的顺序生成。

下面是一些示例代码,展示了我正在努力完成的工作:

batch_size = 100
num_batches = 10
input_data = (tf.constant(['text_data']), tf.constant(13))
ds = tf.data.Dataset.from_tensors(input_data).repeat(batch_size * num_batches)
ds = ds.batch(batch_size)
# ds = ... (multithreaded data transformations on batches of records happen here)
ds = ds.unbatch()

问题是我尝试这样做的方法要么不起作用,要么形成主要瓶颈,因为它们是单线程的。以下是其中一些方法:

  1. unbatch - 单线程,太慢
  2. interleave/flat_map - flat_map 不接受张量元组 - "takes 2 positional arguments but" [num_features] "were given"
  3. interleave/带有 py_function 的自定义函数 - 不起作用,因为 py_function 不能 return Dataset
  4. interleave/没有 py_function 的自定义函数 - 不起作用,因为在图形模式下,无法迭代张量

我需要将 unbatch 替换为某种将批次分发到多个线程的方法,这些线程独立地取消批处理它们,然后交错来自不同线程的结果。有什么想法吗?

这是我最终找到的版本,使用 interleavefrom_tensor_slices:

batch_size = 100
num_batches = 10
num_threads = 4
input_data = (tf.constant(['text_data']), tf.constant(13))
ds = tf.data.Dataset.from_tensors(input_data).repeat(batch_size * num_batches)
ds = ds.batch(batch_size)
# ds = ... (multithreaded data transformations on batches of records happen here)
ds = ds.interleave(lambda *args:tf.data.Dataset.from_tensor_slices(args), num_threads, 1, num_threads)