如何检查 tf.data 数据集对象?

How to checkpoint tf.data Dataset objects?

在训练期间检查点时(在 crash/etc 的情况下)我保存了图形和参数,但不清楚如何对用于输入的新 tf.data 对象执行相同操作。

是否有一种直接的方法来检查这些点,以便我可以继续当前的纪元,或者恢复随机播放状态(可能来自种子?)

tf.contrib.data.make_saveable_from_iterator() function takes a tf.data.Iterator object and gives you back a "saveable object" that can be saved using a tf.train.Saver。它保存迭代器的整个状态,包括任何打乱的数据。

以下示例代码显示了如何将简单的迭代器添加到用于变量的同一检查点:

ds = tf.data.Dataset.range(10)
iterator = ds.make_initializable_iterator()

# [Build the training graph, using `iterator.get_next()` as the input.]

# Build the iterator SaveableObject.
saveable_obj = tf.contrib.data.make_saveable_from_iterator(iterator)

# Add the SaveableObject to the SAVEABLE_OBJECTS collection so
# it will be saved automatically using a Saver.
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)

# Create a saver that saves all objects in the `tf.GraphKeys.SAVEABLE_OBJECTS`
# collection.
saver = tf.train.Saver()

with tf.Session() as sess:
  while continue_training:

    # [Perform training.]

    if should_save_checkpoint:
      saver.save(sess, ...)

请注意,迭代器检查点支持目前(从 TensorFlow 1.8 开始)处于实验状态,因此检查点格式可能会从一个版本更改为下一个版本。