使用数据集 api tensorflow 即时生成

on the fly generation with Dataset api tensorflow

我有一个函数可以生成特征张量和目标张量。例如。

x,t = myfunc() ##x,t tensors

如何将其与 TensorFlow 的数据集 API 集成以进行持续训练?理想情况下,我想使用数据集来设置批处理、转换等内容。

编辑澄清:问题是我不仅想将 x 和 t 放在我的图中,还想从中制作一个数据集,这样我就可以使用我为(普通)有限数据集实现的相同数据集处理我可以使用可初始化的迭代器加载到内存中并输入到同一个图中。

如果 x 和 t 是张量,您可以通过调用 tf.data.Dataset.from_tensorstf.data.Dataset.from_tensor_slices(文档 here)来创建数据集。

它们之间的区别在于from_tensors将输入张量组合成数据集中的单个元素。 from_tensor_slices 创建一个数据集,每个切片有一个元素。

假设 xttf.Tensor 对象,并且 my_func() 构建了一个 TensorFlow 图,您可以使用以下方法与 `Dataset.map():

# Creates an infinite dataset with a dummy value. You can make this finite by
# specifying an explicit number of elements to `repeat()`.
dummy_dataset = tf.data.Dataset.from_tensors(0).repeat(None)

# Evaluates `my_func` once for each element in `dummy_dataset`.
dataset = dummy_dataset.map(lambda _: my_func())