如何合并两个(或更多)TensorFlow 数据集?

How can I merge two (or more) TensorFlow datasets?

我已经获取了具有 3 个分区的 CelebA 数据集,如下所示

>>> celeba_bldr = tfds.builder('celeb_a')
>>> datasets = celeba_bldr.as_dataset()
>>> datasets.keys()
dict_keys(['test', 'train', 'validation'])

ds_train = datasets['train']
ds_test = datasets['test']
ds_valid = datasets['validation']

现在,我想将它们全部合并到一个数据集中。例如,我需要将训练和验证结合在一起,或者可能,将它们全部合并在一起,然后根据我自己的不同主题分离标准将它们拆分。有没有办法做到这一点?

我在文档中找不到执行此操作的任何选项 https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/data/Dataset

查看您链接的文档,数据集似乎有 concatenate 方法,所以我假设您可以获得一个联合数据集:

ds_train = datasets['train']
ds_test = datasets['test']
ds_valid = datasets['validation']

ds = ds_train.concatenate(ds_test).concatenate(ds_valid)

参见:https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/data/Dataset#concatenate

我还想提一下,如果您需要连接 多个 数据集(例如,数据集列表),您可以采用更有效的方式:

ds_l = [ds_1, ds_2, ds_3] # list of `Dataset` objects
# 1. create dataset where each element is a `tf.data.Dataset` object
ds = tf.data.Dataset.from_tensor_slice(ds_l)
# 2. extract all elements from datasets and concat them into one dataset
concat_ds = ds.interleave(lambda x: x, cycle_length=1, num_parallel_calls=tf.data.AUTOTUNE)

您也可以使用 flat_map(),但我认为使用 interleave() 进行并行调用会更快。一般来说interleaveflat_map的推广。

如果数据集来自同一个 TFDS 数据集,您可以直接将它们与拆分合并 API:

ds = tfds.load('celeb_a', split='train+test+validation')

或者使用特殊的all拆分:

ds = tfds.load('celeb_a', split='all')

文档:https://www.tensorflow.org/datasets/splits