如何从生成器创建固定长度 tf.Dataset?

How to create a fixed length tf.Dataset from generator?

我有一个生成无限量数据的生成器(随机图像裁剪)。我想基于比方说 10,000 个第一个数据点创建一个 tf.Dataset 并将其缓存以使用它们来训练模型?

目前,我有一个生成器需要 1-2 秒来创建每个数据点,这是主要的性能障碍。我必须等待一分钟才能生成一批 64 张图像(preprocessing() 函数非常昂贵,所以我想重用结果)。

ds = tf.Dataset.from_generator() 方法允许我们创建这样的无限数据集。相反,我想使用生成器的 N 个第一个输出创建一个有限数据集并将其缓存如下:

ds = ds.cache().


替代解决方案是不断生成新数据,并在渲染生成器时使用缓存数据点。

您可以使用 Dataset.cache 函数和 Dataset.take 函数来完成此操作。

如果一切都在记忆中,就像做这样的事情一样简单:

def generate_example():
  i = 0
  while(True):
    print ('yielding value {}'.format(i))
    yield tf.random.uniform((64,64,3))
    i +=1

ds = tf.data.Dataset.from_generator(generate_example, tf.float32)

first_n_datapoints = ds.take(n).cache()

现在请注意,如果我将 n 设置为 3 say 然后做一些微不足道的事情,例如:

for i in first_n_datapoints.repeat():
  print ('')
  print (i.shape)

然后我看到确认前 3 个值被缓存的输出(对于生成的前 3 个值中的每一个,我只看到 yielding value {i} 输出一次:

yielding value 0
(64,64,3)
yielding value 1
(64,64,3)
yielding value 2
(64,64,3)
(64,64,3)
(64,64,3)
(64,64,3)
...

如果所有内容都不适合内存,那么我们可以将文件路径传递给缓存函数,它将生成的张量缓存到磁盘。

更多信息在这里:https://www.tensorflow.org/api_docs/python/tf/data/Dataset#cache