如何将值注入 TensorFlow 图的中间?
How to inject values into the middle of TensorFlow graph?
考虑以下代码:
x = tf.placeholder(tf.float32, (), name='x')
z = x + tf.constant(5.0)
y = tf.mul(z, tf.constant(0.5))
with tf.Session() as sess:
print(sess.run(y, feed_dict={x: 30}))
结果图是 x -> z -> y。有时我对从 x 一直计算 y 很感兴趣,但有时我从 z 开始并想将这个值注入到图中。所以 z 需要表现得像一个部分占位符。我该怎么做?
(对于任何感兴趣的人,为什么我需要这个。我正在使用一个自动编码器网络,该网络观察图像 x,生成中间压缩表示 z,然后计算图像 y 的重建。我想看看网络是什么当我为 z 注入不同的值时重建。)
按以下方式使用默认占位符:
x = tf.placeholder(tf.float32, (), name='x')
# z is a placeholder with default value
z = tf.placeholder_with_default(x+tf.constant(5.0), (), name='z')
y = tf.mul(z, tf.constant(0.5))
with tf.Session() as sess:
# and feed the z in
print(sess.run(y, feed_dict={z: 5}))
傻我
我不能评论你的post,@iramusa,所以我会给出一个答案。您不需要使用 placeholder_with_default。您可以将值输入到您想要的任何节点:
import tensorflow as tf
x = tf.placeholder(tf.float32,(), name='x')
z = x + tf.constant(5.0)
y = z*tf.constant(0.5)
with tf.Session() as sess:
print(sess.run(y, feed_dict={x: 2})) # get 3.5
print(sess.run(y, feed_dict={z: 5})) # get 2.5
print(sess.run(y, feed_dict={y: 5})) # get 5
考虑以下代码:
x = tf.placeholder(tf.float32, (), name='x')
z = x + tf.constant(5.0)
y = tf.mul(z, tf.constant(0.5))
with tf.Session() as sess:
print(sess.run(y, feed_dict={x: 30}))
结果图是 x -> z -> y。有时我对从 x 一直计算 y 很感兴趣,但有时我从 z 开始并想将这个值注入到图中。所以 z 需要表现得像一个部分占位符。我该怎么做?
(对于任何感兴趣的人,为什么我需要这个。我正在使用一个自动编码器网络,该网络观察图像 x,生成中间压缩表示 z,然后计算图像 y 的重建。我想看看网络是什么当我为 z 注入不同的值时重建。)
按以下方式使用默认占位符:
x = tf.placeholder(tf.float32, (), name='x')
# z is a placeholder with default value
z = tf.placeholder_with_default(x+tf.constant(5.0), (), name='z')
y = tf.mul(z, tf.constant(0.5))
with tf.Session() as sess:
# and feed the z in
print(sess.run(y, feed_dict={z: 5}))
傻我
我不能评论你的post,@iramusa,所以我会给出一个答案。您不需要使用 placeholder_with_default。您可以将值输入到您想要的任何节点:
import tensorflow as tf
x = tf.placeholder(tf.float32,(), name='x')
z = x + tf.constant(5.0)
y = z*tf.constant(0.5)
with tf.Session() as sess:
print(sess.run(y, feed_dict={x: 2})) # get 3.5
print(sess.run(y, feed_dict={z: 5})) # get 2.5
print(sess.run(y, feed_dict={y: 5})) # get 5