如何拆分 Tensorflow 数据集?

How do I split Tensorflow datasets?

我有一个基于一个 .tfrecord 文件的张量流数据集。如何将数据集拆分为测试和训练数据集?例如。 70% 训练和 30% 测试?

编辑:

我的 Tensorflow 版本:1.8 我检查过,可能的副本中没有提到 "split_v" 函数。我也在使用 tfrecord 文件。

您可以使用 Dataset.take()Dataset.skip():

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(test_size)
test_dataset = test_dataset.take(test_size)

为了更普遍,我给出了一个使用 70/15/15 train/val/test 拆分的示例,但如果您不需要测试或验证集,只需忽略最后两行。

Take:

Creates a Dataset with at most count elements from this dataset.

Skip:

Creates a Dataset that skips count elements from this dataset.

您可能还想查看 Dataset.shard():

Creates a Dataset that includes only 1/num_shards of this dataset.

这个问题与this one and 类似,恐怕我们还没有得到满意的答案。

  • 使用 take()skip() 需要知道数据集的大小。如果我不知道,或者不想知道怎么办?

  • 使用 shard() 只给出 1 / num_shards 的数据集。如果我想要剩下的怎么办?

我尝试在下面提供一个更好的解决方案,仅在 TensorFlow 2 上进行了测试。假设你已经有一个 shuffled 数据集,然后你可以使用 filter() 将它分成两部分:

import tensorflow as tf

all = tf.data.Dataset.from_tensor_slices(list(range(1, 21))) \
        .shuffle(10, reshuffle_each_iteration=False)

test_dataset = all.enumerate() \
                    .filter(lambda x,y: x % 4 == 0) \
                    .map(lambda x,y: y)

train_dataset = all.enumerate() \
                    .filter(lambda x,y: x % 4 != 0) \
                    .map(lambda x,y: y)

for i in test_dataset:
    print(i)

print()

for i in train_dataset:
    print(i)

参数reshuffle_each_iteration=False很重要。它确保原始数据集被洗牌一次,不再洗牌。否则,两个结果集可能会有一些重叠。

使用enumerate()添加索引。

使用 filter(lambda x,y: x % 4 == 0) 从 4 个中抽取 1 个样本。同样,x % 4 != 0 从 4 个样本中抽取 3 个。

使用map(lambda x,y: y)剥离索引并恢复原始样本。

此示例实现了 75/25 拆分。

x % 5 == 0x % 5 != 0 给出了 80/20 的比例。

如果您真的想要 70/30 的比例,x % 10 < 3x % 10 >= 3 应该可以。

更新:

从 TensorFlow 2.0.0 开始,由于 AutoGraph's limitations,上述代码可能会导致一些警告。要消除这些警告,请单独声明所有 lambda 函数:

def is_test(x, y):
    return x % 4 == 0

def is_train(x, y):
    return not is_test(x, y)

recover = lambda x,y: y

test_dataset = all.enumerate() \
                    .filter(is_test) \
                    .map(recover)

train_dataset = all.enumerate() \
                    .filter(is_train) \
                    .map(recover)

这在我的机器上没有发出警告。并且让 is_train() 成为 not is_test() 绝对是一个很好的做法。