What problem does a reinitializable iterator solve?
A reinitializable iterator can be initialized from multiple different
Dataset objects. For example, you might have a training input pipeline
that uses random perturbations to the input images to improve
generalization, and a validation input pipeline that evaluates
predictions on unmodified data. These pipelines will typically use
different Dataset objects that have the same structure (i.e. the same
types and compatible shapes for each component).
The following example is given:
# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)

# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
                                           training_dataset.output_shapes)
next_element = iterator.get_next()

training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

# Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):
  # Initialize an iterator over the training dataset.
  sess.run(training_init_op)
  for _ in range(100):
    sess.run(next_element)

  # Initialize an iterator over the validation dataset.
  sess.run(validation_init_op)
  for _ in range(50):
    sess.run(next_element)
It is not clear what the benefit of this complexity is.
Why not simply create two different iterators?
The original motivation for reinitializable iterators was the following: the user's input data lives in two or more tf.data.Dataset objects that have the same structure but different pipeline definitions.

For example, you might have a training data pipeline that applies augmentations in a Dataset.map(), and an evaluation data pipeline that produces raw examples, but both pipelines yield batches with the same structure (in terms of the number of tensors, their element types, shapes, etc.).

The user would define a single training graph that takes its input from a tf.data.Iterator, created with Iterator.from_structure(). The user could then switch between the different input data sources by reinitializing the iterator from one of the datasets.
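The pattern can be sketched with plain Python generators, outside TensorFlow entirely: two "pipelines" that yield elements of the same structure, and a single consumer that is "reinitialized" from whichever source is active. The generator and function names here are illustrative stand-ins, not tf.data APIs:

```python
import random

def training_pipeline():
    # Analogous to Dataset.range(100).map(augment): perturbed values.
    for x in range(100):
        yield x + random.randint(-10, 10)

def validation_pipeline():
    # Analogous to Dataset.range(50): raw values.
    for x in range(50):
        yield x

def consume(source):
    # A single consumer that works with any source of the same
    # element structure; counting elements stands in for a train step.
    return sum(1 for _ in source)

# Switch input sources by "reinitializing" from one pipeline or the other.
print(consume(training_pipeline()))    # 100
print(consume(validation_pipeline()))  # 50
```

In graph-mode TensorFlow 1.x the consumer is a static graph that cannot simply be handed a different Python iterable, which is why the switching had to happen through an explicit initializer op.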
In hindsight, reinitializable iterators have proven to be quite hard to use for their intended purpose. In TensorFlow 2.0 (or 1.x with eager execution enabled), it is much easier to create iterators over different datasets using idiomatic Python for loops and high-level training APIs:
tf.enable_eager_execution()

model = ...       # A `tf.keras.Model`, or some other class exposing `fit()` and `evaluate()` methods.
train_data = ...  # A `tf.data.Dataset`.
eval_data = ...   # A `tf.data.Dataset`.

for i in range(NUM_EPOCHS):
  model.fit(train_data, ...)

  # Evaluate every 5 epochs.
  if i % 5 == 0:
    model.evaluate(eval_data, ...)