如何用现有的分片 tfrecords 替换 tfds 数据集
How to replace tfds dataset with existing sharded tfrecords
我正在使用使用 tfds 数据集的克隆代码,并希望以尽可能少的修改使其适应一组预先存在的分片 tfrecrods。
具体而言,克隆的代码执行以下操作:
builder = tfds.builder(dataset, data_dir)
builder.download_and_prepare()
...
estimator.train(
data_lib.build_input_fn(builder, True), max_steps=train_steps
)
在此代码中,'dataset' 是 tfds 数据集的名称(例如 cifar10 或 others)。反而,
我想在已经是分片 tfrecords 形式的外部数据集上进行训练,即:
'train_<shard_id>-<no_samples>.tfrecords'
'val_<shard_id>-<no_samples>.tfrecords'
并驻留在存储桶中(如果该信息有帮助,则在 google 云端)。
我一直在研究 Adding new datasets in TFDS format,但这似乎需要一个完整的管道来从头开始生成 tfrecords,这是不可能的,而且鉴于 tfrecords 已经存在,这似乎是多余的。我确定我缺少对现有 tfrecords 的一些简单改编..
如有任何建议,我们将不胜感激。
阿罗娜,
您的期望是正确的:有一个特殊函数 tf.data.TFRecordDataset
用于处理 tfrecords 中的数据。像这样在你的 input_fn 中使用它:
def input_fn(features, labels, training=True, batch_size=256):
file_paths = [file0, file1] # pass tfrecords filenames here
dataset = tf.data.TFRecordDataset(file_paths)
# Shuffle and repeat if you are in training mode.
if training:
dataset = dataset.shuffle(1000).repeat()
return dataset.batch(batch_size)
我正在使用使用 tfds 数据集的克隆代码,并希望以尽可能少的修改使其适应一组预先存在的分片 tfrecrods。
具体而言,克隆的代码执行以下操作:
builder = tfds.builder(dataset, data_dir)
builder.download_and_prepare()
...
estimator.train(
data_lib.build_input_fn(builder, True), max_steps=train_steps
)
在此代码中,'dataset' 是 tfds 数据集的名称(例如 cifar10 或 others)。反而, 我想在已经是分片 tfrecords 形式的外部数据集上进行训练,即:
'train_<shard_id>-<no_samples>.tfrecords'
'val_<shard_id>-<no_samples>.tfrecords'
并驻留在存储桶中(如果该信息有帮助,则在 google 云端)。
我一直在研究 Adding new datasets in TFDS format,但这似乎需要一个完整的管道来从头开始生成 tfrecords,这是不可能的,而且鉴于 tfrecords 已经存在,这似乎是多余的。我确定我缺少对现有 tfrecords 的一些简单改编..
如有任何建议,我们将不胜感激。
阿罗娜,
您的期望是正确的:有一个特殊函数 tf.data.TFRecordDataset
用于处理 tfrecords 中的数据。像这样在你的 input_fn 中使用它:
def input_fn(features, labels, training=True, batch_size=256):
file_paths = [file0, file1] # pass tfrecords filenames here
dataset = tf.data.TFRecordDataset(file_paths)
# Shuffle and repeat if you are in training mode.
if training:
dataset = dataset.shuffle(1000).repeat()
return dataset.batch(batch_size)