Tensorflow Dataset API 在完成一个 epoch 后恢复 Iterator

Tensorflow Dataset API restore Iterator after completing one epoch

我有 190 个特征和标签,我的批量大小是 20,但在 9 次迭代后 tf.reshape 返回异常 重塑的输入是一个具有 21 个值的张量,但请求的形状有 60 并且我知道这是由于 Iterator.get_next() 造成的。如何恢复我的迭代器以便它再次从头开始提供批处理服务?

如果您想从 Dataset 的开头重新启动 tf.data.Iterator,请考虑使用 initializable 迭代器,它有一个操作,您可以 运行 re-initialize 迭代器:

dataset = ...  # A `tf.data.Dataset` instance.
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

train_op = ...  # Something that depends on `next_element`.

for _ in range(NUM_EPOCHS):
  # Initialize the iterator at the beginning of `dataset`.
  sess.run(iterator.initializer)

  # Loop over the examples in `iterator`, running `train_op`.
  try:
    while True:
      sess.run(train_op)

  except tf.errors.OutOfRangeError:  # Thrown at the end of the epoch.
    pass

  # Perform any per-epoch computations here.

有关不同类型 Iterator 的更多详细信息,请参阅 the tf.data programmer's guide