避免使用分配操作使张量流图混乱
Avoid cluttering the tensorflow graph with assign operations
我必须运行像下面的代码
import tensorflow as tf
sess = tf.Session()
x = tf.Variable(42.)
for i in range(10000):
sess.run(x.assign(42.))
sess.run(x)
print(i)
几次。实际代码要复杂得多,使用的变量也更多。
问题是 TensorFlow 图随着每个实例化的赋值运算而增长,这使得图增长,最终减慢了计算速度。
我可以使用feed_dict=
设置值,但我想将我的状态保留在图表中,以便我可以在其他地方轻松查询它。
在这种情况下,有什么方法可以避免使当前图形混乱吗?
我想我已经找到了一个很好的解决方案:
我定义了一个占位符 y
并创建了一个将 y
的值分配给 x
的操作。
然后我可以重复使用该操作,使用 feed_dict={y: value}
为 x 分配一个新值。
这不会向图中添加另一个操作。
事实证明,循环运行 比以前快 很多。
import tensorflow as tf
sess = tf.Session()
x = tf.Variable(42.)
y = tf.placeholder(dtype=tf.float32)
assign = x.assign(y)
sess.run(tf.initialize_all_variables())
for i in range(10000):
sess.run(assign, feed_dict={y: i})
print(i, sess.run(x))
每次调用sess.run(x.assign(42.))
发生了两件事:(i) 一个新的 assign
操作被添加到计算图 sess.graph
,(ii) 新添加的操作被执行。难怪如果循环重复多次,图形会变得非常大。如果您在执行前定义赋值操作(下例中的 asgnmnt_operation
),则只向图中添加一个操作,因此性能非常好:
import tensorflow as tf
x = tf.Variable(42.)
c = tf.constant(42.)
asgnmnt_operation = x.assign(c)
sess = tf.Session()
for i in range(10000):
sess.run(asgnmnt_operation)
sess.run(x)
print(i)
我必须运行像下面的代码
import tensorflow as tf
sess = tf.Session()
x = tf.Variable(42.)
for i in range(10000):
sess.run(x.assign(42.))
sess.run(x)
print(i)
几次。实际代码要复杂得多,使用的变量也更多。 问题是 TensorFlow 图随着每个实例化的赋值运算而增长,这使得图增长,最终减慢了计算速度。
我可以使用feed_dict=
设置值,但我想将我的状态保留在图表中,以便我可以在其他地方轻松查询它。
在这种情况下,有什么方法可以避免使当前图形混乱吗?
我想我已经找到了一个很好的解决方案:
我定义了一个占位符 y
并创建了一个将 y
的值分配给 x
的操作。
然后我可以重复使用该操作,使用 feed_dict={y: value}
为 x 分配一个新值。
这不会向图中添加另一个操作。
事实证明,循环运行 比以前快 很多。
import tensorflow as tf
sess = tf.Session()
x = tf.Variable(42.)
y = tf.placeholder(dtype=tf.float32)
assign = x.assign(y)
sess.run(tf.initialize_all_variables())
for i in range(10000):
sess.run(assign, feed_dict={y: i})
print(i, sess.run(x))
每次调用sess.run(x.assign(42.))
发生了两件事:(i) 一个新的 assign
操作被添加到计算图 sess.graph
,(ii) 新添加的操作被执行。难怪如果循环重复多次,图形会变得非常大。如果您在执行前定义赋值操作(下例中的 asgnmnt_operation
),则只向图中添加一个操作,因此性能非常好:
import tensorflow as tf
x = tf.Variable(42.)
c = tf.constant(42.)
asgnmnt_operation = x.assign(c)
sess = tf.Session()
for i in range(10000):
sess.run(asgnmnt_operation)
sess.run(x)
print(i)