How to use feedable iterator from Tensorflow Dataset API along with MonitoredTrainingSession?
The TensorFlow programmer's guide recommends using a feedable iterator to switch between training and validation datasets without re-initializing the iterator. The main requirement is to feed a handle to choose between them.

How can this be used together with tf.train.MonitoredTrainingSession?
The following approach fails with "RuntimeError: Graph is finalized and cannot be modified.":
with tf.train.MonitoredTrainingSession() as sess:
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())
How can I combine the convenience of MonitoredTrainingSession with iterating over both the training and validation datasets?
I got the answer from a TensorFlow GitHub issue - https://github.com/tensorflow/tensorflow/issues/12859

The solution is to call iterator.string_handle() before creating the MonitoredSession.
import tensorflow as tf
from tensorflow.contrib.data import Dataset, Iterator

dataset_train = Dataset.range(10)
dataset_val = Dataset.range(90, 100)

# Create the string handles *before* the MonitoredTrainingSession
# finalizes the graph.
iter_train_handle = dataset_train.make_one_shot_iterator().string_handle()
iter_val_handle = dataset_val.make_one_shot_iterator().string_handle()

handle = tf.placeholder(tf.string, shape=[])
iterator = Iterator.from_string_handle(
    handle, dataset_train.output_types, dataset_train.output_shapes)
next_batch = iterator.get_next()

with tf.train.MonitoredTrainingSession() as sess:
    handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])

    for step in range(10):
        print('train', sess.run(next_batch, feed_dict={handle: handle_train}))

        if step % 3 == 0:
            print('val', sess.run(next_batch, feed_dict={handle: handle_val}))
Output:
('train', 0)
('val', 90)
('train', 1)
('train', 2)
('val', 91)
('train', 3)
@Michael Jaison G's answer is correct. However, it does not work when you also want to use certain session_run_hooks that need to evaluate parts of the graph, e.g. LoggingTensorHook or SummarySaverHook.

The following example will cause an error:
import tensorflow as tf

dataset_train = tf.data.Dataset.range(10)
dataset_val = tf.data.Dataset.range(90, 100)

iter_train_handle = dataset_train.make_one_shot_iterator().string_handle()
iter_val_handle = dataset_val.make_one_shot_iterator().string_handle()

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, dataset_train.output_types, dataset_train.output_shapes)
feature = iterator.get_next()

pred = feature * feature
tf.summary.scalar('pred', pred)
global_step = tf.train.create_global_step()

summary_hook = tf.train.SummarySaverHook(save_steps=5,
                                         output_dir="summaries",
                                         summary_op=tf.summary.merge_all())

with tf.train.MonitoredTrainingSession(hooks=[summary_hook]) as sess:
    handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])

    for step in range(10):
        feat = sess.run(feature, feed_dict={handle: handle_train})
        pred_ = sess.run(pred, feed_dict={handle: handle_train})
        print('train: ', feat)
        print('pred: ', pred_)

        if step % 3 == 0:
            print('val', sess.run(feature, feed_dict={handle: handle_val}))
This will fail with the error:
InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder' with dtype string
[[Node: Placeholder = Placeholder[dtype=DT_STRING, shape=[], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
[[Node: cond/Switch_1/_15 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_18_cond/Switch_1", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
The reason is that the hooks already try to evaluate the graph on the first session.run([iter_train_handle, iter_val_handle]) call, which obviously does not yet contain the handle in its feed_dict.
The workaround is to override the hooks that cause the problem and to change the code in before_run and after_run so that they only evaluate on session.run calls whose feed_dict contains the handle (you can access the feed_dict of the current session.run call via the run_context argument of before_run and after_run).
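As a framework-independent sketch of that workaround (the class and method names below only mimic the shape of a TensorFlow SessionRunHook; in real TF the feed_dict would come from run_context.original_args.feed_dict), the overridden hook only does its work when the handle placeholder is present in the current call's feed_dict:

```python
class HandleAwareHook:
    """Sketch of a hook whose before_run/after_run only act on
    session.run calls that actually feed the iterator handle.
    Names are illustrative, not the real tf.train.SessionRunHook API."""

    def __init__(self, handle_key):
        self.handle_key = handle_key
        self.runs_seen = 0      # every run call observed
        self.runs_handled = 0   # only those where the handle was fed

    def before_run(self, feed_dict):
        self.runs_seen += 1
        # Skip evaluation entirely when the handle is not in feed_dict.
        self._active = feed_dict is not None and self.handle_key in feed_dict
        return feed_dict if self._active else None

    def after_run(self, results):
        if self._active:
            self.runs_handled += 1


hook = HandleAwareHook("handle")

# First run: no handle fed (mimics the initial
# session.run([iter_train_handle, iter_val_handle]) call).
hook.before_run(None)
hook.after_run(None)

# Second run: the handle is fed, so the hook does its work.
hook.before_run({"handle": "train_handle_bytes"})
hook.after_run("some_result")

print(hook.runs_seen, hook.runs_handled)  # 2 1
```

The same guard, placed at the top of a real hook's before_run and after_run, prevents it from touching graph elements on runs that cannot feed the placeholder.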
Alternatively, you can use the latest TensorFlow master (post-1.4), which adds a run_step_fn function to MonitoredSession. It allows you to specify a step_fn like the following, which avoids the error (at the cost of evaluating the if statement TrainingIterations times ...):
# handle_train / handle_val start out as None and are filled in on the
# first step; step_context.session replaces a captured sess.
handle_train, handle_val = None, None

def step_fn(step_context):
    global handle_train, handle_val
    if handle_train is None:
        handle_train, handle_val = step_context.session.run(
            [iter_train_handle, iter_val_handle])
    return step_context.run_with_hooks(fetches=..., feed_dict=...)
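The lazy one-time initialization inside step_fn can be sketched without TensorFlow as follows; FakeSession and FakeStepContext are stand-ins invented here, and only the control flow matches the snippet above:

```python
class FakeSession:
    """Stand-in for a TF session: echoes back what it was asked to run."""
    def run(self, fetches, feed_dict=None):
        return (fetches, feed_dict)


class FakeStepContext:
    """Stand-in for the step context passed to step_fn by run_step_fn."""
    def __init__(self, session):
        self.session = session

    def run_with_hooks(self, fetches, feed_dict):
        # The real method would invoke the registered hooks around the run.
        return self.session.run(fetches, feed_dict)


handles = {"train": None, "val": None}  # mutable state shared with step_fn

def step_fn(step_context):
    # Fetch the string handles exactly once, on the first step only.
    if handles["train"] is None:
        handles["train"], handles["val"] = "train_handle", "val_handle"
    return step_context.run_with_hooks(
        fetches="next_batch", feed_dict={"handle": handles["train"]})


ctx = FakeStepContext(FakeSession())
for _ in range(3):
    result = step_fn(ctx)

print(result)  # ('next_batch', {'handle': 'train_handle'})
```

The dictionary plays the role that global (or nonlocal) handle variables play in the real snippet: the if branch runs once, and every later step reuses the cached handles.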
There is a demo of using a placeholder with a SessionRunHook in mot_session. The demo is about switching datasets by feeding different handle_strings.

By the way, I have tried all the solutions, but only this one works.