在 Keras 中使用 TF 数据集 API 的既定方法是用 `make_one_shot_iterator()` 喂养 `model.fit`,但是这个迭代器只适用于一个 Epoch

The established way to use TF Dataset API in Keras is to feed `model.fit` with `make_one_shot_iterator()`, But this iterator only good for one Epoch

编辑:

为了阐明为什么这个问题与建议的重复项不同,这个 SO 问题跟进了那些建议的重复项,即 Keras 到底在用那些 SO 问题中描述的技术做什么。建议的重复项指定在 model.fit 中使用数据集 API make_one_shot_iterator(),我的后续行动是 make_one_shot_iterator() 只能通过数据集一次,但是在给出的解决方案中,几个时期被指定。


这是对这些 SO 问题的跟进

Tensorflow keras with tf dataset input

其中 "Starting from Tensorflow 1.9, one can pass tf.data.Dataset object directly into keras.Model.fit() and it would act similar to fit_generator"。每个示例都有一个 TF 数据集一次性迭代器输入 Kera 的 model.fit。

例子如下

# Load mnist training data
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
training_set = tfdata_generator(x_train, y_train,is_training=True)

model = # your keras model here              
model.fit(
    training_set.make_one_shot_iterator(),
    steps_per_epoch=len(x_train) // 128,
    epochs=5,
    verbose = 1)

但是,根据 Tensorflow 数据集 API 指南(此处 https://www.tensorflow.org/guide/datasets ):

A one-shot iterator is the simplest form of iterator, which only supports iterating once through a dataset

所以它只适用于 1 个 epoch。但是,SO 问题中的代码指定了几个 epoch,上面的代码示例指定了 5 个 epoch。

这个矛盾有什么解释吗? Keras 是否知道当一次性迭代器遍历数据集时,它可以重新初始化和打乱数据?

您可以简单地将数据集对象传递给 model.fit,Keras 将处理迭代。 考虑其中一个预制数据集:

train, test = tf.keras.datasets.cifar10.load_data()
dataset = tf.data.Dataset.from_tensor_slices((train[0], train[1]))

这将从 cifar10 数据集的训练数据创建数据集对象。在这种情况下,不需要解析函数。 如果您从包含 numpy 数组列表图像的路径创建数据集,您将需要一个。

dataset = tf.data.Dataset.from_tensor_slices((image_path, labels_path)) 

如果您需要一个函数来从文件名加载实际数据。没有 tf.read_file

就可以用同样的方式处理 Numpy 数组
def parse_func(filename):
    f = tf.read_file(filename)
    image = tf.image.decode_image(f)
    label = #get label from filename
    return image, label

然后您可以对这个数据集进行洗牌、批处理和映射任何解析函数。您可以控制使用洗牌缓冲区预加载多少示例。重复控制纪元计数,最好保留 None,这样它会无限期地重复。您可以使用普通批处理函数或结合

dataset = dataset.shuffle().repeat()
dataset.apply(tf.data.experimental.map_and_batch(map_func=parse_func, batch_size,num_parallel_batches))

然后数据集对象可以传递给model.fit model.fit(数据集,时期,steps_per_epoch)。请注意,steps_per_epoch 在这种情况下是必要的参数,它将定义何时开始新纪元。所以你必须提前知道时代的大小。