tf.train.MonitoredTrainingSession 和数据集中可重新初始化的迭代器
tf.train.MonitoredTrainingSession and reinitializable iterator from Dataset
似乎 MonitoredTrainingSession 在第一次调用 .运行(..) 之前做了一些操作(记录?),这意味着当我这样做时:
train_data = reader.traindata() # returns a tf.contrib.data.Dataset
it = tf.contrib.data.Iterator.from_structure(train_data.output_types, train_data.output_shapes)
init_train = it.make_initializer(train_data)
ne = it.get_next()
ts = tf.train.MonitoredTrainingSession(checkpoint_dir=save_path)
... no calls to ts.run ...
ts.run(init_train)
这会产生错误:
FailedPreconditionError (see above for traceback): GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element
所以它看起来好像 MonitoredTrainingSession 在 运行 执行我提供给它的操作之前正在执行一些操作,这使得无法将数据集中的可重新初始化迭代器一起使用。
我确定我错过了什么,很想听听:-)
看来 Tensorflow 中还没有直接的解决方案。是的,奇怪的是他们没有完全支持数据集 API.
原因是监控会话在从检查点加载时跳到运行init_op
。因此迭代器初始值设定项应该是局部变量。
本期给出了当前的变通建议 - https://github.com/tensorflow/tensorflow/issues/12859
scaffold = tf.train.Scaffold(local_init_op=tf.group(tf.local_variables_initializer(),
init_train))
with tf.train.MonitoredTrainingSession(scaffold=scaffold,
checkpoint_dir=checkpoint_dir) as sess:
while not sess.should_stop():
sess.run(train_op)
似乎 MonitoredTrainingSession 在第一次调用 .运行(..) 之前做了一些操作(记录?),这意味着当我这样做时:
train_data = reader.traindata() # returns a tf.contrib.data.Dataset
it = tf.contrib.data.Iterator.from_structure(train_data.output_types, train_data.output_shapes)
init_train = it.make_initializer(train_data)
ne = it.get_next()
ts = tf.train.MonitoredTrainingSession(checkpoint_dir=save_path)
... no calls to ts.run ...
ts.run(init_train)
这会产生错误:
FailedPreconditionError (see above for traceback): GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element
所以它看起来好像 MonitoredTrainingSession 在 运行 执行我提供给它的操作之前正在执行一些操作,这使得无法将数据集中的可重新初始化迭代器一起使用。
我确定我错过了什么,很想听听:-)
看来 Tensorflow 中还没有直接的解决方案。是的,奇怪的是他们没有完全支持数据集 API.
原因是监控会话在从检查点加载时跳到运行init_op
。因此迭代器初始值设定项应该是局部变量。
本期给出了当前的变通建议 - https://github.com/tensorflow/tensorflow/issues/12859
scaffold = tf.train.Scaffold(local_init_op=tf.group(tf.local_variables_initializer(),
init_train))
with tf.train.MonitoredTrainingSession(scaffold=scaffold,
checkpoint_dir=checkpoint_dir) as sess:
while not sess.should_stop():
sess.run(train_op)