如何用现有的分片 tfrecords 替换 tfds 数据集

How to replace tfds dataset with existing sharded tfrecords

我正在使用使用 tfds 数据集的克隆代码,并希望以尽可能少的修改使其适应一组预先存在的分片 tfrecrods。

具体而言,克隆的代码执行以下操作:

builder = tfds.builder(dataset, data_dir)
builder.download_and_prepare()
...
estimator.train(
        data_lib.build_input_fn(builder, True), max_steps=train_steps
)

在此代码中,'dataset' 是 tfds 数据集的名称(例如 cifar10 或 others)。反而, 我想在已经是分片 tfrecords 形式的外部数据集上进行训练,即:

'train_<shard_id>-<no_samples>.tfrecords'

'val_<shard_id>-<no_samples>.tfrecords'

并驻留在存储桶中(如果该信息有帮助,则在 google 云端)。

我一直在研究 Adding new datasets in TFDS format,但这似乎需要一个完整的管道来从头开始生成 tfrecords,这是不可能的,而且鉴于 tfrecords 已经存在,这似乎是多余的。我确定我缺少对现有 tfrecords 的一些简单改编..

如有任何建议,我们将不胜感激。

阿罗娜,

您的期望是正确的:有一个特殊函数 tf.data.TFRecordDataset 用于处理 tfrecords 中的数据。像这样在你的 input_fn 中使用它:

def input_fn(features, labels, training=True, batch_size=256):
    
    file_paths = [file0, file1]  # pass tfrecords filenames here
    dataset = tf.data.TFRecordDataset(file_paths)

    # Shuffle and repeat if you are in training mode.
    if training:
        dataset = dataset.shuffle(1000).repeat()
    
    return dataset.batch(batch_size)

在 TF 站点阅读更多内容:1 2