TensorFlow Debugger ValueError: Node name 'Add/x' is not found in partition graphs of device
TensorFlow Debugger ValueError: Node name 'Add/x' is not found in partition graphs of device
我正在使用 TensorFlow 1.6,试图在我的程序中设置 TensorFlow 调试器 tfdbg。当我在 tfdbg 终端中输入 run 命令时,出现以下错误:
Traceback (most recent call last):
File "/Users/Documents/imputation/main.py", line 346, in <module>
args_ = _Parser(description='Train/evaluate the network for incidents '
File "/Users/Documents/imputation/main.py", line 312, in parse_args
command(args, parser)
File "/Users/Documents/imputation/main.py", line 222, in _call
args_dict = _Train._call(namespace, parser)
File "/Users/Documents/imputation/main.py", line 151, in _call
train(**args_dict)
File "/Users/Documents/imputation/tf_impute.py", line 185, in train
mon_sess.run([train_op,
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 546, in run
run_metadata=run_metadata)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 1022, in run
run_metadata=run_metadata)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 1113, in run
raise six.reraise(*original_exc_info)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/six.py", line 693, in reraise
raise value
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 1098, in run
return self._sess.run(*args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 1178, in run
run_metadata=run_metadata))
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/debug/wrappers/hooks.py", line 150, in after_run
self._session_wrapper.on_run_end(on_run_end_request)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/debug/wrappers/local_cli_wrapper.py", line 323, in on_run_end
self._dump_root, partition_graphs=partition_graphs)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/debug/lib/debug_data.py", line 495, in __init__
self._load_all_device_dumps(partition_graphs, validate)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/debug/lib/debug_data.py", line 517, in _load_all_device_dumps
self._load_partition_graphs(partition_graphs, validate)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/debug/lib/debug_data.py", line 797, in _load_partition_graphs
self._validate_dump_with_graphs(debug_graph.device_name)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/debug/lib/debug_data.py", line 842, in _validate_dump_with_graphs
"device %s." % (datum.node_name, device_name))
ValueError: Node name 'Add/x' is not found in partition graphs of device /job:localhost/replica:0/task:0/device:CPU:0.
我也在查看 https://github.com/tensorflow/tensorflow/issues/8753 中的问题,其中讨论了类似的问题,但提供的解决方案对我不起作用。我已经尝试将 tfdbg 实现为会话的包装器,也在挂钩中。我实现 tfdbg 的代码部分如下所示:
class _LoggerHook(tf.train.SessionRunHook):
    """Session hook that accumulates the training loss and periodically
    prints an RMS-error / throughput summary every `print_step` steps."""

    # Running sum of per-step loss values since the last report.
    cumulative_loss = 0

    def begin(self):
        # Start at -1 so the first before_run call brings the step to 0.
        self._step = -1
        self._start_time = time.time()

    def before_run(self, run_context):
        self._step += 1
        # Ask the session to also evaluate the loss tensor on this step.
        return tf.train.SessionRunArgs(loss)

    def after_run(self, run_context, run_values):
        self.cumulative_loss += run_values.results
        if self._step == 0:
            print('Starting training at %s' % datetime.now())
            return
        if self._step % print_step != 0:
            return
        # Report window complete: compute timing and error statistics.
        now = time.time()
        elapsed = now - self._start_time
        self._start_time = now
        rms_error = math.sqrt(2 * self.cumulative_loss / print_step)
        self.cumulative_loss = 0
        examples_per_sec = print_step * batch_size / elapsed
        sec_per_batch = float(elapsed / print_step)
        format_str = (
            '%s: %d examples, rms_error = %.6f (%.1f examples/sec; '
            '%.3f sec/batch)')
        print(format_str % (
            datetime.now(), self._step * batch_size, rms_error,
            examples_per_sec, sec_per_batch))
# Total optimizer steps: epochs times the number of whole batches per epoch.
max_steps = epochs * (examples // batch_size)
# Saver restricted to the model variables collection (tf.model_variables()).
model_saver = tf.train.Saver(var_list=tf.model_variables())
class _CheckpointSaverHook(CheckpointSaverHook):
    """CheckpointSaverHook variant that forbids listeners and suppresses
    the base class's end-of-session checkpoint save."""

    def __init__(self, *args, **kwargs):
        super(_CheckpointSaverHook, self).__init__(*args, **kwargs)
        # NOTE(review): assert is stripped under `python -O`; kept to match
        # the original's intent of rejecting any CheckpointSaverListener.
        assert self._listeners == [], 'CheckpointSaverListener not ' \
                                      'allowed'

    def end(self, session):
        # Deliberate no-op (the pasted original was missing this body,
        # which is a SyntaxError). Presumably the final checkpoint is
        # written by a separate end-of-training hook instead — TODO confirm.
        pass
class _FinalStepHook(FinalOpsHook):
    """FinalOpsHook that additionally writes one last model checkpoint
    when the session ends."""

    def end(self, session):
        # Let FinalOpsHook evaluate its final ops first.
        super(_FinalStepHook, self).end(session)
        step_value = session.run(global_step)
        print('Saving last checkpoint at step %d' % step_value)
        model_saver.save(session,
                         os.path.join(train_dir, "model.ckpt"),
                         global_step)
# Hook that runs train_op and preds_update_op once more when the session ends.
final_hook = _FinalStepHook([train_op, preds_update_op])
# Scaffold shares the restricted model_saver with the checkpoint hooks.
scaffold = tf.train.Scaffold(saver=model_saver)
logger_hook = _LoggerHook()
hooks = [_CheckpointSaverHook(checkpoint_dir=train_dir, save_secs=1000,
                              scaffold=scaffold),
         tf.train.StopAtStepHook(last_step=max_steps - 1),
         tf.train.NanTensorHook(loss), logger_hook, final_hook,
         # Attaches the tfdbg local CLI to every session.run call.
         tf_debug.LocalCLIDebugHook()]
config = tf.ConfigProto(log_device_placement=log_device_placement)
# Grow GPU memory on demand instead of reserving it all upfront.
config.gpu_options.allow_growth = True
start_train = time.time()
# save_checkpoint_secs=0: checkpointing is presumably left entirely to the
# explicit hooks above rather than the session's default saver — TODO confirm.
with tf.train.MonitoredTrainingSession(checkpoint_dir=train_dir,
                                       hooks=hooks, config=config,
                                       save_checkpoint_secs=0,
                                       scaffold=scaffold) as mon_sess:
    try:
        while not mon_sess.should_stop():
            mon_sess.run([train_op,
                          # globals_preds
                          ])
    except OutOfRangeError as e:
        # Input pipeline exhausted (end of dataset).
        print(e)
        print('global step %s' % logger_hook._step)
    except KeyboardInterrupt:
        print('Train interrupted at global step %s' % logger_hook._step)
print('Training %d examples in %d epochs took %s' % (
    examples, epochs, secs_to_time(time.time() - start_train)))
# Archive the training directory to S3 (best effort, per helper semantics).
upload_timestamped_tar(s3_url, train_dir, keep_dir, keep_tar, wait)
# Second final-op result, i.e. the last preds_update_op value.
return final_hook.final_ops_values[1]
你知道如何解决这个问题吗?
我现在解决了这个问题。问题是我在代码中的某处使用了加号 +
而不是 tf.add
。当我检查 TensorBoard 中的图表时,我意识到名为小写 "add/x" 的节点已经存在(如图所示)。
将我代码中相应的部分改为 tf.add 之后,TensorBoard 中的节点也随之变为首字母大写的 "Add/x"(如图所示)。最后,TensorFlow Debugger 能够正确识别该节点,现在可以正常工作了。
我正在使用 TensorFlow 1.6,试图在我的程序中设置 TensorFlow 调试器 tfdbg。当我在 tfdbg 终端中输入 run 命令时,出现以下错误:
Traceback (most recent call last):
File "/Users/Documents/imputation/main.py", line 346, in <module>
args_ = _Parser(description='Train/evaluate the network for incidents '
File "/Users/Documents/imputation/main.py", line 312, in parse_args
command(args, parser)
File "/Users/Documents/imputation/main.py", line 222, in _call
args_dict = _Train._call(namespace, parser)
File "/Users/Documents/imputation/main.py", line 151, in _call
train(**args_dict)
File "/Users/Documents/imputation/tf_impute.py", line 185, in train
mon_sess.run([train_op,
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 546, in run
run_metadata=run_metadata)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 1022, in run
run_metadata=run_metadata)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 1113, in run
raise six.reraise(*original_exc_info)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/six.py", line 693, in reraise
raise value
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 1098, in run
return self._sess.run(*args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 1178, in run
run_metadata=run_metadata))
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/debug/wrappers/hooks.py", line 150, in after_run
self._session_wrapper.on_run_end(on_run_end_request)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/debug/wrappers/local_cli_wrapper.py", line 323, in on_run_end
self._dump_root, partition_graphs=partition_graphs)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/debug/lib/debug_data.py", line 495, in __init__
self._load_all_device_dumps(partition_graphs, validate)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/debug/lib/debug_data.py", line 517, in _load_all_device_dumps
self._load_partition_graphs(partition_graphs, validate)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/debug/lib/debug_data.py", line 797, in _load_partition_graphs
self._validate_dump_with_graphs(debug_graph.device_name)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow/python/debug/lib/debug_data.py", line 842, in _validate_dump_with_graphs
"device %s." % (datum.node_name, device_name))
ValueError: Node name 'Add/x' is not found in partition graphs of device /job:localhost/replica:0/task:0/device:CPU:0.
我也在查看 https://github.com/tensorflow/tensorflow/issues/8753 中的问题,其中讨论了类似的问题,但提供的解决方案对我不起作用。我已经尝试将 tfdbg 实现为会话的包装器,也在挂钩中。我实现 tfdbg 的代码部分如下所示:
class _LoggerHook(tf.train.SessionRunHook):
    """Session hook that accumulates the training loss and periodically
    prints an RMS-error / throughput summary every `print_step` steps."""

    # Running sum of per-step loss values since the last report.
    cumulative_loss = 0

    def begin(self):
        # Start at -1 so the first before_run call brings the step to 0.
        self._step = -1
        self._start_time = time.time()

    def before_run(self, run_context):
        self._step += 1
        # Ask the session to also evaluate the loss tensor on this step.
        return tf.train.SessionRunArgs(loss)

    def after_run(self, run_context, run_values):
        self.cumulative_loss += run_values.results
        if self._step == 0:
            print('Starting training at %s' % datetime.now())
            return
        if self._step % print_step != 0:
            return
        # Report window complete: compute timing and error statistics.
        now = time.time()
        elapsed = now - self._start_time
        self._start_time = now
        rms_error = math.sqrt(2 * self.cumulative_loss / print_step)
        self.cumulative_loss = 0
        examples_per_sec = print_step * batch_size / elapsed
        sec_per_batch = float(elapsed / print_step)
        format_str = (
            '%s: %d examples, rms_error = %.6f (%.1f examples/sec; '
            '%.3f sec/batch)')
        print(format_str % (
            datetime.now(), self._step * batch_size, rms_error,
            examples_per_sec, sec_per_batch))
# Total optimizer steps: epochs times the number of whole batches per epoch.
max_steps = epochs * (examples // batch_size)
# Saver restricted to the model variables collection (tf.model_variables()).
model_saver = tf.train.Saver(var_list=tf.model_variables())
class _CheckpointSaverHook(CheckpointSaverHook):
    """CheckpointSaverHook variant that forbids listeners and suppresses
    the base class's end-of-session checkpoint save."""

    def __init__(self, *args, **kwargs):
        super(_CheckpointSaverHook, self).__init__(*args, **kwargs)
        # NOTE(review): assert is stripped under `python -O`; kept to match
        # the original's intent of rejecting any CheckpointSaverListener.
        assert self._listeners == [], 'CheckpointSaverListener not ' \
                                      'allowed'

    def end(self, session):
        # Deliberate no-op (the pasted original was missing this body,
        # which is a SyntaxError). Presumably the final checkpoint is
        # written by a separate end-of-training hook instead — TODO confirm.
        pass
class _FinalStepHook(FinalOpsHook):
    """FinalOpsHook that additionally writes one last model checkpoint
    when the session ends."""

    def end(self, session):
        # Let FinalOpsHook evaluate its final ops first.
        super(_FinalStepHook, self).end(session)
        step_value = session.run(global_step)
        print('Saving last checkpoint at step %d' % step_value)
        model_saver.save(session,
                         os.path.join(train_dir, "model.ckpt"),
                         global_step)
# Hook that runs train_op and preds_update_op once more when the session ends.
final_hook = _FinalStepHook([train_op, preds_update_op])
# Scaffold shares the restricted model_saver with the checkpoint hooks.
scaffold = tf.train.Scaffold(saver=model_saver)
logger_hook = _LoggerHook()
hooks = [_CheckpointSaverHook(checkpoint_dir=train_dir, save_secs=1000,
                              scaffold=scaffold),
         tf.train.StopAtStepHook(last_step=max_steps - 1),
         tf.train.NanTensorHook(loss), logger_hook, final_hook,
         # Attaches the tfdbg local CLI to every session.run call.
         tf_debug.LocalCLIDebugHook()]
config = tf.ConfigProto(log_device_placement=log_device_placement)
# Grow GPU memory on demand instead of reserving it all upfront.
config.gpu_options.allow_growth = True
start_train = time.time()
# save_checkpoint_secs=0: checkpointing is presumably left entirely to the
# explicit hooks above rather than the session's default saver — TODO confirm.
with tf.train.MonitoredTrainingSession(checkpoint_dir=train_dir,
                                       hooks=hooks, config=config,
                                       save_checkpoint_secs=0,
                                       scaffold=scaffold) as mon_sess:
    try:
        while not mon_sess.should_stop():
            mon_sess.run([train_op,
                          # globals_preds
                          ])
    except OutOfRangeError as e:
        # Input pipeline exhausted (end of dataset).
        print(e)
        print('global step %s' % logger_hook._step)
    except KeyboardInterrupt:
        print('Train interrupted at global step %s' % logger_hook._step)
print('Training %d examples in %d epochs took %s' % (
    examples, epochs, secs_to_time(time.time() - start_train)))
# Archive the training directory to S3 (best effort, per helper semantics).
upload_timestamped_tar(s3_url, train_dir, keep_dir, keep_tar, wait)
# Second final-op result, i.e. the last preds_update_op value.
return final_hook.final_ops_values[1]
你知道如何解决这个问题吗?
我现在解决了这个问题。问题是我在代码中的某处使用了加号 +
而不是 tf.add
。当我检查 TensorBoard 中的图表时,我意识到名为小写 "add/x" 的节点已经存在(如图所示)。
将我代码中相应的部分改为 tf.add 之后,TensorBoard 中的节点也随之变为首字母大写的 "Add/x"(如图所示)。最后,TensorFlow Debugger 能够正确识别该节点,现在可以正常工作了。