如何拆分 Tensorflow 数据集?
How do I split Tensorflow datasets?
我有一个基于一个 .tfrecord 文件的张量流数据集。如何将数据集拆分为测试和训练数据集?例如。 70% 训练和 30% 测试?
编辑:
我的 Tensorflow 版本:1.8
我检查过,可能的副本中没有提到 "split_v" 函数。我也在使用 tfrecord 文件。
您可以使用 Dataset.take()
和 Dataset.skip()
:
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(test_size)
test_dataset = test_dataset.take(test_size)
为了更普遍,我给出了一个使用 70/15/15 train/val/test 拆分的示例,但如果您不需要测试或验证集,只需忽略最后两行。
Take:
Creates a Dataset with at most count elements from this dataset.
Skip:
Creates a Dataset that skips count elements from this dataset.
您可能还想查看 Dataset.shard()
:
Creates a Dataset that includes only 1/num_shards of this dataset.
这个问题与this one and 类似,恐怕我们还没有得到满意的答案。
使用 take()
和 skip()
需要知道数据集的大小。如果我不知道,或者不想知道怎么办?
使用 shard()
只给出 1 / num_shards
的数据集。如果我想要剩下的怎么办?
我尝试在下面提供一个更好的解决方案,仅在 TensorFlow 2 上进行了测试。假设你已经有一个 shuffled 数据集,然后你可以使用 filter()
将它分成两部分:
import tensorflow as tf
all = tf.data.Dataset.from_tensor_slices(list(range(1, 21))) \
.shuffle(10, reshuffle_each_iteration=False)
test_dataset = all.enumerate() \
.filter(lambda x,y: x % 4 == 0) \
.map(lambda x,y: y)
train_dataset = all.enumerate() \
.filter(lambda x,y: x % 4 != 0) \
.map(lambda x,y: y)
for i in test_dataset:
print(i)
print()
for i in train_dataset:
print(i)
参数reshuffle_each_iteration=False
很重要。它确保原始数据集被洗牌一次,不再洗牌。否则,两个结果集可能会有一些重叠。
使用enumerate()
添加索引。
使用 filter(lambda x,y: x % 4 == 0)
从 4 个中抽取 1 个样本。同样,x % 4 != 0
从 4 个样本中抽取 3 个。
使用map(lambda x,y: y)
剥离索引并恢复原始样本。
此示例实现了 75/25 拆分。
x % 5 == 0
和 x % 5 != 0
给出了 80/20 的比例。
如果您真的想要 70/30 的比例,x % 10 < 3
和 x % 10 >= 3
应该可以。
更新:
从 TensorFlow 2.0.0 开始,由于 AutoGraph's limitations,上述代码可能会导致一些警告。要消除这些警告,请单独声明所有 lambda 函数:
def is_test(x, y):
return x % 4 == 0
def is_train(x, y):
return not is_test(x, y)
recover = lambda x,y: y
test_dataset = all.enumerate() \
.filter(is_test) \
.map(recover)
train_dataset = all.enumerate() \
.filter(is_train) \
.map(recover)
这在我的机器上没有发出警告。并且让 is_train()
成为 not is_test()
绝对是一个很好的做法。
我有一个基于一个 .tfrecord 文件的张量流数据集。如何将数据集拆分为测试和训练数据集?例如。 70% 训练和 30% 测试?
编辑:
我的 Tensorflow 版本:1.8 我检查过,可能的副本中没有提到 "split_v" 函数。我也在使用 tfrecord 文件。
您可以使用 Dataset.take()
和 Dataset.skip()
:
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(test_size)
test_dataset = test_dataset.take(test_size)
为了更普遍,我给出了一个使用 70/15/15 train/val/test 拆分的示例,但如果您不需要测试或验证集,只需忽略最后两行。
Take:
Creates a Dataset with at most count elements from this dataset.
Skip:
Creates a Dataset that skips count elements from this dataset.
您可能还想查看 Dataset.shard()
:
Creates a Dataset that includes only 1/num_shards of this dataset.
这个问题与this one and
使用
take()
和skip()
需要知道数据集的大小。如果我不知道,或者不想知道怎么办?使用
shard()
只给出1 / num_shards
的数据集。如果我想要剩下的怎么办?
我尝试在下面提供一个更好的解决方案,仅在 TensorFlow 2 上进行了测试。假设你已经有一个 shuffled 数据集,然后你可以使用 filter()
将它分成两部分:
import tensorflow as tf
all = tf.data.Dataset.from_tensor_slices(list(range(1, 21))) \
.shuffle(10, reshuffle_each_iteration=False)
test_dataset = all.enumerate() \
.filter(lambda x,y: x % 4 == 0) \
.map(lambda x,y: y)
train_dataset = all.enumerate() \
.filter(lambda x,y: x % 4 != 0) \
.map(lambda x,y: y)
for i in test_dataset:
print(i)
print()
for i in train_dataset:
print(i)
参数reshuffle_each_iteration=False
很重要。它确保原始数据集被洗牌一次,不再洗牌。否则,两个结果集可能会有一些重叠。
使用enumerate()
添加索引。
使用 filter(lambda x,y: x % 4 == 0)
从 4 个中抽取 1 个样本。同样,x % 4 != 0
从 4 个样本中抽取 3 个。
使用map(lambda x,y: y)
剥离索引并恢复原始样本。
此示例实现了 75/25 拆分。
x % 5 == 0
和 x % 5 != 0
给出了 80/20 的比例。
如果您真的想要 70/30 的比例,x % 10 < 3
和 x % 10 >= 3
应该可以。
更新:
从 TensorFlow 2.0.0 开始,由于 AutoGraph's limitations,上述代码可能会导致一些警告。要消除这些警告,请单独声明所有 lambda 函数:
def is_test(x, y):
return x % 4 == 0
def is_train(x, y):
return not is_test(x, y)
recover = lambda x,y: y
test_dataset = all.enumerate() \
.filter(is_test) \
.map(recover)
train_dataset = all.enumerate() \
.filter(is_train) \
.map(recover)
这在我的机器上没有发出警告。并且让 is_train()
成为 not is_test()
绝对是一个很好的做法。