Tensorflow 数据集 API - 行为解释

Tensorflow Dataset API - explanation of behavior

使用下面的代码,我想问几个关于下面到底发生了什么的问题。

dataset = tf.data.TFRecordDataset(filepath)
dataset = dataset.map(parse_function, num_parallel_calls=4)
dataset = dataset.repeat()
dataset = dataset.shuffle(1024)
dataset = dataset.batch(16)
iterator = dataset.make_one_shot_iterator()

1.dataset.map(parse_function, num_parallel_calls=4) - 我们在这里加载了多少条记录?多少适合内存或一些固定数量?

2.dataset = dataset.repeat() - 我们到底重复了什么?当前从点 .1 加载了一段数据?如果是这样,是否意味着我们将不再加载其他人?

3.How 随机播放到底有效吗?

4.Can 我们在映射之前使用重复、随机播放和批处理,并处理文件路径而不是单独处理文件?

  1. 您在此处加载整个数据集。在批处理之前应用地图通常不是一个好主意。 Tensorflow 对张量大小有 2GB 的硬限制。 num_parallel_calls 表示并行应用的映射函数数。
  2. dataset.repeat() 没有指定的纪元值将无限期地重复数据集。
  3. Shuffle 将随机打乱具有指定缓冲区值的数据集。为了正确洗牌,通常最好将此值设置为数据集长度,并在批处理之前应用此函数。
  4. tf.data.TFRecordDataset期望文件名作为输入。一般来说,首选顺序是

    dataset = dataset.shuffle(shuffle_buffer).repeat()
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(map_func)
    

看看https://www.tensorflow.org/guide/performance/datasets

  1. DatasetAPI中的数据是延迟加载的,所以要依赖后面的操作。由于洗牌缓冲区的大小,现在您一次加载 1024 个样本。它需要填充洗牌缓冲区。当您从迭代器中获取值时,数据将被延迟加载。
  2. 你重复加载的数据,因为重复是在地图功能之后。这就是为什么建议在解析数据之前进行 shuffle 的原因,因为它对内存更友好。
  3. 随机加载一些数据(取决于随机缓冲区的大小),然后随机播放该数据。
  4. 是的,你可以重复,随机播放然后映射,甚至建议在 performance guide. And there is also function which merges repeat and shuffle together here