在 Keras 中使用 TF 数据集 API 的既定方法是用 `make_one_shot_iterator()` 喂养 `model.fit`,但是这个迭代器只适用于一个 Epoch
The established way to use TF Dataset API in Keras is to feed `model.fit` with `make_one_shot_iterator()`, But this iterator only good for one Epoch
编辑:
为了阐明为什么这个问题与建议的重复项不同,这个 SO 问题跟进了那些建议的重复项,即 Keras 到底在用那些 SO 问题中描述的技术做什么。建议的重复项指定在 model.fit
中使用数据集 API make_one_shot_iterator()
,我的后续行动是 make_one_shot_iterator()
只能通过数据集一次,但是在给出的解决方案中,几个时期被指定。
这是对这些 SO 问题的跟进
Tensorflow keras with tf dataset input
其中 "Starting from Tensorflow 1.9, one can pass tf.data.Dataset object directly into keras.Model.fit() and it would act similar to fit_generator"。每个示例都有一个 TF 数据集一次性迭代器输入 Kera 的 model.fit。
例子如下
# Load mnist training data
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
training_set = tfdata_generator(x_train, y_train,is_training=True)
model = # your keras model here
model.fit(
training_set.make_one_shot_iterator(),
steps_per_epoch=len(x_train) // 128,
epochs=5,
verbose = 1)
但是,根据 Tensorflow 数据集 API 指南(此处 https://www.tensorflow.org/guide/datasets ):
A one-shot iterator is the simplest form of iterator, which only
supports iterating once through a dataset
所以它只适用于 1 个 epoch。但是,SO 问题中的代码指定了几个 epoch,上面的代码示例指定了 5 个 epoch。
这个矛盾有什么解释吗? Keras 是否知道当一次性迭代器遍历数据集时,它可以重新初始化和打乱数据?
您可以简单地将数据集对象传递给 model.fit
,Keras 将处理迭代。
考虑其中一个预制数据集:
train, test = tf.keras.datasets.cifar10.load_data()
dataset = tf.data.Dataset.from_tensor_slices((train[0], train[1]))
这将从 cifar10 数据集的训练数据创建数据集对象。在这种情况下,不需要解析函数。
如果您从包含 numpy 数组列表图像的路径创建数据集,您将需要一个。
dataset = tf.data.Dataset.from_tensor_slices((image_path, labels_path))
如果您需要一个函数来从文件名加载实际数据。没有 tf.read_file
就可以用同样的方式处理 Numpy 数组
def parse_func(filename):
f = tf.read_file(filename)
image = tf.image.decode_image(f)
label = #get label from filename
return image, label
然后您可以对这个数据集进行洗牌、批处理和映射任何解析函数。您可以控制使用洗牌缓冲区预加载多少示例。重复控制纪元计数,最好保留 None,这样它会无限期地重复。您可以使用普通批处理函数或结合
dataset = dataset.shuffle().repeat()
dataset.apply(tf.data.experimental.map_and_batch(map_func=parse_func, batch_size,num_parallel_batches))
然后数据集对象可以传递给model.fit
model.fit(数据集,时期,steps_per_epoch)。请注意,steps_per_epoch
在这种情况下是必要的参数,它将定义何时开始新纪元。所以你必须提前知道时代的大小。
编辑:
为了阐明为什么这个问题与建议的重复项不同,这个 SO 问题跟进了那些建议的重复项,即 Keras 到底在用那些 SO 问题中描述的技术做什么。建议的重复项指定在 model.fit
中使用数据集 API make_one_shot_iterator()
,我的后续行动是 make_one_shot_iterator()
只能通过数据集一次,但是在给出的解决方案中,几个时期被指定。
这是对这些 SO 问题的跟进
Tensorflow keras with tf dataset input
其中 "Starting from Tensorflow 1.9, one can pass tf.data.Dataset object directly into keras.Model.fit() and it would act similar to fit_generator"。每个示例都有一个 TF 数据集一次性迭代器输入 Kera 的 model.fit。
例子如下
# Load mnist training data
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
training_set = tfdata_generator(x_train, y_train,is_training=True)
model = # your keras model here
model.fit(
training_set.make_one_shot_iterator(),
steps_per_epoch=len(x_train) // 128,
epochs=5,
verbose = 1)
但是,根据 Tensorflow 数据集 API 指南(此处 https://www.tensorflow.org/guide/datasets ):
A one-shot iterator is the simplest form of iterator, which only supports iterating once through a dataset
所以它只适用于 1 个 epoch。但是,SO 问题中的代码指定了几个 epoch,上面的代码示例指定了 5 个 epoch。
这个矛盾有什么解释吗? Keras 是否知道当一次性迭代器遍历数据集时,它可以重新初始化和打乱数据?
您可以简单地将数据集对象传递给 model.fit
,Keras 将处理迭代。
考虑其中一个预制数据集:
train, test = tf.keras.datasets.cifar10.load_data()
dataset = tf.data.Dataset.from_tensor_slices((train[0], train[1]))
这将从 cifar10 数据集的训练数据创建数据集对象。在这种情况下,不需要解析函数。 如果您从包含 numpy 数组列表图像的路径创建数据集,您将需要一个。
dataset = tf.data.Dataset.from_tensor_slices((image_path, labels_path))
如果您需要一个函数来从文件名加载实际数据。没有 tf.read_file
def parse_func(filename):
f = tf.read_file(filename)
image = tf.image.decode_image(f)
label = #get label from filename
return image, label
然后您可以对这个数据集进行洗牌、批处理和映射任何解析函数。您可以控制使用洗牌缓冲区预加载多少示例。重复控制纪元计数,最好保留 None,这样它会无限期地重复。您可以使用普通批处理函数或结合
dataset = dataset.shuffle().repeat()
dataset.apply(tf.data.experimental.map_and_batch(map_func=parse_func, batch_size,num_parallel_batches))
然后数据集对象可以传递给model.fit
model.fit(数据集,时期,steps_per_epoch)。请注意,steps_per_epoch
在这种情况下是必要的参数,它将定义何时开始新纪元。所以你必须提前知道时代的大小。