TF数据集API:以下顺序是否正确?映射、缓存、随机播放、批处理、重复、预取

TF Dataset API: Is the following sequence correct? map,cache,shuffle,batch,repeat,prefetch

我正在使用这个序列从磁盘读取图像文件并输入到 TF Keras 模型中。

  #Make dataset for training
    dataset_train = tf.data.Dataset.from_tensor_slices((file_ids_training,file_names_training))
    dataset_train = dataset_train.flat_map(lambda file_id,file_name: tf.data.Dataset.from_tensor_slices(
        tuple (tf.py_func(_get_data_for_dataset, [file_id,file_name], [tf.float32,tf.float32]))))
    dataset_train = dataset_train.cache()

    dataset_train= dataset_train.shuffle(buffer_size=train_buffer_size)
    dataset_train= dataset_train.batch(train_batch_size) #Make dataset, shuffle, and create batches
    dataset_train= dataset_train.repeat()
    dataset_train = dataset_train.prefetch(1)
    dataset_train_iterator = dataset_train.make_one_shot_iterator()
    get_train_batch = dataset_train_iterator.get_next()

我对这是否是最佳序列有疑问。 例如repeat 应该在 shuffle() 之后和 batch() 之前吗?,cache() 应该在 batch 之后吗?

这里的答案 建议在批处理之前重复或随机播放。我经常使用的顺序是 (1) shuffle,(2) repeat,(3) map,(4) batch 但它可以根据您的喜好而有所不同。我在 repeat 之前使用 shuffle 以避免模糊 epoch 边界。我在批处理之前使用 map,因为我的映射函数适用于单个示例(而不是一批示例),但您当然可以编写一个矢量化的 map 函数,并期望将批处理视为输入。