何时在 Tensorflow Estimator 中使用迭代器

When to use an iterator in Tensorflow Estimator

在 Tensorflow 指南中,指南在两个单独的位置描述了 Iris 数据示例的输入函数。一个输入函数 return 只是数据集本身,而另一个 return 是带有迭代器的数据集。

来自预制 Estimator 指南:https://www.tensorflow.org/guide/premade_estimators

def train_input_fn(features, labels, batch_size):
"""An input function for training"""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

# Shuffle, repeat, and batch the examples.
return dataset.shuffle(1000).repeat().batch(batch_size)

来自自定义估算器指南:https://www.tensorflow.org/guide/custom_estimators

def train_input_fn(features, labels, batch_size):
"""An input function for training"""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)

# Return the read end of the pipeline.
return dataset.make_one_shot_iterator().get_next()

我很困惑哪一个是正确的,如果它们都用于不同的情况,那么什么时候 return 使用迭代器的数据集是正确的?

如果您的输入函数 return 是 tf.data.Dataset,则会在后台创建一个迭代器,其 get_next() 函数用于为模型提供输入。这在源代码中有些隐藏,参见 parse_input_fn_result here.

我相信这只是在最近的更新中实现的,所以旧的教程仍然在他们的输入函数中明确地 return get_next() 因为它是当时唯一的选择。使用两者应该没有区别,但是您可以通过 return 数据集而不是迭代器来节省一小部分代码。