tensorflow - Input pipeline with multiple TFRecord files + tf.contrib.data.sliding_window_batch()

I have multiple TFRecord files, each of which holds a specific time frame of my data. The data points are consecutive within each file, but not across files. As part of my input pipeline, I use tf.contrib.data.sliding_window_batch to slide a window over the data points, like this:

import os
import tensorflow as tf

# Read every TFRecord file in the data directory into one dataset.
filenames = [os.path.join(data_dir, f) for f in os.listdir(data_dir)]
dataset = tf.data.TFRecordDataset(filenames)

dataset = dataset.map(parser_fn, num_parallel_calls=6)
dataset = dataset.map(preprocessing_fn, num_parallel_calls=6)
dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=y + z)) # sliding window
dataset = dataset.map(lambda x: prepare_fn(x, y, z))
dataset = dataset.shuffle(buffer_size=100000)
dataset = dataset.batch(32)
dataset = dataset.repeat()
dataset = dataset.prefetch(2)

How can I prevent a window from spanning data points from different files?

Solved it with tf.data.Dataset.filter(predicate):

filenames = [os.path.join(data_dir, f) for f in os.listdir(data_dir)]
dataset = tf.data.TFRecordDataset(filenames)

dataset = dataset.map(parser_fn, num_parallel_calls=6)
dataset = dataset.map(preprocessing_fn, num_parallel_calls=6)
dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=y + z)) # sliding window
# Keep only windows whose first and last records share the same 'timeframe',
# i.e. drop any window that spans two files.
dataset = dataset.filter(lambda x: tf.equal(x['timeframe'][0], x['timeframe'][-1]))
dataset = dataset.map(lambda x: prepare_fn(x, y, z))
dataset = dataset.shuffle(buffer_size=100000)
dataset = dataset.batch(32)
dataset = dataset.repeat()
dataset = dataset.prefetch(2)
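
For the predicate to have something to compare, parser_fn has to expose a per-record 'timeframe' feature identifying which file (or time range) the record came from. Below is a minimal sketch of such a parser, assuming the records were written with an int64 'timeframe' feature and a float 'value' payload; both feature names are placeholders for whatever the actual records contain.

def parser_fn(serialized_example):
  # Hypothetical feature spec; adjust to match how the TFRecords were written.
  return tf.parse_single_example(
      serialized_example,
      features={
          'timeframe': tf.FixedLenFeature([], tf.int64),  # assumed file/time-range id
          'value': tf.FixedLenFeature([], tf.float32),    # assumed payload
      })

After sliding_window_batch, x['timeframe'] is a vector of length y + z, and since the files are read sequentially, a window that crosses a file boundary always has different first and last entries, which is exactly what the filter checks.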

Another approach is to build the batches independently for each file and interleave the results:

def interleave_fn(filename):
  # Build a per-file dataset so that sliding windows never cross file boundaries.
  dataset = tf.data.TFRecordDataset(filename)
  dataset = dataset.map(parser_fn, num_parallel_calls=6)
  dataset = dataset.map(preprocessing_fn, num_parallel_calls=6)
  dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=y + z)) # sliding window
  return dataset

filenames = [os.path.join(data_dir, f) for f in os.listdir(data_dir)]
dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.interleave(interleave_fn, num_parallel_calls=...)
dataset = dataset.map(lambda x: prepare_fn(x, y, z))
dataset = dataset.shuffle(buffer_size=1000000)
dataset = dataset.batch(32)
dataset = dataset.repeat()
dataset = dataset.prefetch(2)

This is probably more efficient, since it bypasses the filtering step: no cross-file windows are generated only to be thrown away.
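
On newer TensorFlow releases (2.x) tf.contrib.data.sliding_window_batch no longer exists, but the same per-file sliding windows can be built with Dataset.window plus flat_map. A rough sketch under that assumption, windowing the raw serialized records before parsing (parsing and preprocessing would then have to work on batched examples, e.g. with tf.io.parse_example); cycle_length=4 is an arbitrary choice:

import os
import tensorflow as tf

def make_window_dataset(filename, window_size):
  # One dataset per file, so windows never cross file boundaries.
  ds = tf.data.TFRecordDataset(filename)
  # Sliding window with stride 1; each window is itself a small dataset.
  ds = ds.window(size=window_size, shift=1, drop_remainder=True)
  # Collapse each window into a single [window_size] tensor of serialized records.
  ds = ds.flat_map(lambda w: w.batch(window_size))
  return ds

filenames = tf.data.Dataset.list_files(os.path.join(data_dir, '*'), shuffle=False)
dataset = filenames.interleave(
    lambda f: make_window_dataset(f, y + z),
    cycle_length=4,
    num_parallel_calls=tf.data.AUTOTUNE)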