用 tf.estimator 初始化 tf.contrib.data.Iterator

Initialization of tf.contrib.data.Iterator with tf.estimator

如果 tf.estimator.Estimator 也被使用,应该如何初始化 tf.contrib.data.Iterator

其中一个问题是输入图 (tf 图处理输入的部分) 应该在 intput_fn() 中定义 - 因为 tf.estimator 创建单独的图表。

这个要求使得很难访问迭代器init ops并传递它们to tf.estimator(传递操作可以在调用train/evaluate/predict时以挂钩的形式完成)。

使用SessionManager作为钩子可以解决同样的问题。

sm = tf.train.SessionManager(local_init_op=iterator_init_op)
...
estimator = tf.train.Estimator(...)
estimator.train(input_fn, hooks=[sm], steps=None, max_steps=None)

一个选项是将您的 input_fn 包装在另一个设置简单 SessionRunHook init_hook 的函数中。所有操作都在 input_fn 内定义,它在与模型的其余部分相同的图形中被调用,但您可以从中将 iterator_init_op 设置为 init_hook 上的属性。

def get_input_fn(mode="train"):
    init_hook = IteratorInitHook()

    def input_fn():
        ...
        iterator = dataset.make_initializable_iterator()
        init_hook.iterator_init_op = iterator.initializer

    return input_fn, init_hook

class IteratorInitHook(tf.train.SessionRunHook):

    def after_create_session(self, session, coord):
        session.run(self.iterator_init_op)

现在,在构建 Experiment 时,您可以获得这些输入函数和初始化挂钩,它们会在创建 train/eval 会话时调用。它应该与 estimator.train.

等效
train_input_fn, train_init_hook = get_input_fn("train")
test_input_fn, test_init_hook = get_input_fn("test")

return tf.contrib.learn.Experiment(
    estimator=estimator,
    train_input_fn=train_input_fn,
    eval_input_fn=test_input_fn,
    train_monitors=[train_init_hook],
    eval_hooks=[test_init_hook],
)