你如何获得tensorflow中模型的输出和输入值?
How do you obtain the output and input values of a model in tensorflow?
我正在研究 GAN,并决定使用 HyperGAN 来实现我的算法。它是使用 TensorFlow 的 DCGAN 包装器。 HyperGAN 使用 TF
的检查点方法保存输出。
后来,我尝试运行加载模型使用:
import tensorflow as tf
sess=tf.Session()
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
sess.run(tf.global_variables_initializer())
但是,由于它是 GAN,因此需要输入潜在向量并输出图像。这是使用
完成的
out_image = sess.run(last_node, feed_dict(input_node: value))
但是自从我加载模型后,我不知道最后一个节点的名称是什么以及输入节点占位符的名称是什么。我如何获取最初用于创建图形的名称?我尝试使用 TensorBoard
进行可视化,但图表很大,因此卡住了。
您应该尝试打印图中的张量列表:
with tf.Graph().as_default() as graph:
....
count = 0
for op in graph.get_operations():
print op.values()
count+=1
if count == 50:
assert False
为了查看图表的前 50 个节点,您将看到类似这样的内容:
(<tf.Tensor 'import/Placeholder_only:0' shape=<unknown> dtype=float32>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_max:0' shape=() dtype=float32>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_min:0' shape=() dtype=float32>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_const:0' shape=(512,) dtype=quint8>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53:0' shape=(512,) dtype=float32>,)
我把计数放在那里是因为通常终端会打印出如此多的张量,以至于初始输入节点名称在终端中消失。
最后简单的注释掉计数使用的行:
#count = 0
for op in graph.get_operations():
print op.values()
#count+=1
#if count == 50:
# assert False
打印最后几个节点(即您的输出节点)。
我正在研究 GAN,并决定使用 HyperGAN 来实现我的算法。它是使用 TensorFlow 的 DCGAN 包装器。 HyperGAN 使用 TF
的检查点方法保存输出。
后来,我尝试运行加载模型使用:
import tensorflow as tf
sess=tf.Session()
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
sess.run(tf.global_variables_initializer())
但是,由于它是 GAN,因此需要输入潜在向量并输出图像。这是使用
完成的out_image = sess.run(last_node, feed_dict(input_node: value))
但是自从我加载模型后,我不知道最后一个节点的名称是什么以及输入节点占位符的名称是什么。我如何获取最初用于创建图形的名称?我尝试使用 TensorBoard
进行可视化,但图表很大,因此卡住了。
您应该尝试打印图中的张量列表:
with tf.Graph().as_default() as graph:
....
count = 0
for op in graph.get_operations():
print op.values()
count+=1
if count == 50:
assert False
为了查看图表的前 50 个节点,您将看到类似这样的内容:
(<tf.Tensor 'import/Placeholder_only:0' shape=<unknown> dtype=float32>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_max:0' shape=() dtype=float32>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_min:0' shape=() dtype=float32>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_const:0' shape=(512,) dtype=quint8>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53:0' shape=(512,) dtype=float32>,)
我把计数放在那里是因为通常终端会打印出如此多的张量,以至于初始输入节点名称在终端中消失。
最后简单的注释掉计数使用的行:
#count = 0
for op in graph.get_operations():
print op.values()
#count+=1
#if count == 50:
# assert False
打印最后几个节点(即您的输出节点)。