在 Tensorflow 中的线程之间共享变量
Sharing variables between threads in Tensorflow
我正在尝试使用 Python 个线程通过 TensorFlow 实现异步梯度下降。在主代码中,我定义了图表,包括一个训练操作,它获取一个变量来保持对 global_step
:
的计数
with tf.variable_scope("scope_global_step") as scope_global_step:
global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
train_op = optimizer.minimize(loss, global_step=global_step)
如果我打印 global_step
的名字,我得到:
scope_global_step/global_step:0
主代码还启动了几个线程来执行一个training
方法:
threads = [threading.Thread(target=training, args=(sess, train_op, loss, scope_global_step)) for i in xrange(NUM_TRAINING_THREADS)]
for t in threads: t.start()
如果 global_step
的值大于或等于 FLAGS.max_steps
,我希望每个线程都停止执行。为此,我构建了 training
方法,如下所示:
def training(sess, train_op, loss, scope_global_step):
while (True):
_, loss_value = sess.run([train_op, loss])
with tf.variable_scope(scope_global_step, reuse=True):
global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
global_step = global_step.eval(session=sess)
if global_step >= FLAGS.max_steps: break
失败并显示消息:
ValueError: Under-sharing: Variable scope_global_step/global_step does not exist, disallowed. Did you mean to set reuse=None in VarScope?
我可以看到 :0
在第一次创建时添加到变量的名称中,当我尝试检索它时,没有使用该后缀。为什么是这样?
如果我在尝试检索它时手动将后缀添加到变量的名称中,它仍然声称该变量不存在。为什么 TensorFlow 找不到变量?变量不应该在线程之间自动共享吗?我的意思是,所有线程都在同一个会话中 运行 ,对吗?
还有另一个与我的 training
方法相关的问题:将 global_step.eval(session=sess)
再次执行图形,还是在执行 [=] 之后获取分配给 gloabl_step
的值23=] 和 loss
操作?一般来说,从要在 Python 代码中使用的变量中获取值的推荐方法是什么?
TL;DR: 传递您在第一个代码片段中创建的 global_step
tf.Variable
对象作为训练线程参数之一,并调用sess.run(global_step)
在传入的变量上。
作为一般规则,您的训练循环(尤其是单独线程中的训练循环)不应修改图形。 tf.variable_scope()
上下文管理器和 tf.get_variable()
可以 修改图形(尽管它们并不总是),因此您不应在训练循环中使用它们。最安全的做法是在创建训练线程时将 global_step
对象(您首先创建的)作为 args
元组之一传递。然后你可以简单地将你的训练函数重写为:
def training(sess, train_op, loss, global_step):
while (True):
_, loss_value = sess.run([train_op, loss])
current_step = sess.run(global_step)
if current_step >= FLAGS.max_steps: break
为了回答你的其他问题,运行 global_step.eval(session=sess)
或 sess.run(global_step)
只获取 global_step
变量的当前值,不会重新执行其余的你的图表。这是获取 tf.Variable
值以用于 Python 代码的推荐方法。
我正在尝试使用 Python 个线程通过 TensorFlow 实现异步梯度下降。在主代码中,我定义了图表,包括一个训练操作,它获取一个变量来保持对 global_step
:
with tf.variable_scope("scope_global_step") as scope_global_step:
global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
train_op = optimizer.minimize(loss, global_step=global_step)
如果我打印 global_step
的名字,我得到:
scope_global_step/global_step:0
主代码还启动了几个线程来执行一个training
方法:
threads = [threading.Thread(target=training, args=(sess, train_op, loss, scope_global_step)) for i in xrange(NUM_TRAINING_THREADS)]
for t in threads: t.start()
如果 global_step
的值大于或等于 FLAGS.max_steps
,我希望每个线程都停止执行。为此,我构建了 training
方法,如下所示:
def training(sess, train_op, loss, scope_global_step):
while (True):
_, loss_value = sess.run([train_op, loss])
with tf.variable_scope(scope_global_step, reuse=True):
global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
global_step = global_step.eval(session=sess)
if global_step >= FLAGS.max_steps: break
失败并显示消息:
ValueError: Under-sharing: Variable scope_global_step/global_step does not exist, disallowed. Did you mean to set reuse=None in VarScope?
我可以看到 :0
在第一次创建时添加到变量的名称中,当我尝试检索它时,没有使用该后缀。为什么是这样?
如果我在尝试检索它时手动将后缀添加到变量的名称中,它仍然声称该变量不存在。为什么 TensorFlow 找不到变量?变量不应该在线程之间自动共享吗?我的意思是,所有线程都在同一个会话中 运行 ,对吗?
还有另一个与我的 training
方法相关的问题:将 global_step.eval(session=sess)
再次执行图形,还是在执行 [=] 之后获取分配给 gloabl_step
的值23=] 和 loss
操作?一般来说,从要在 Python 代码中使用的变量中获取值的推荐方法是什么?
TL;DR: 传递您在第一个代码片段中创建的 global_step
tf.Variable
对象作为训练线程参数之一,并调用sess.run(global_step)
在传入的变量上。
作为一般规则,您的训练循环(尤其是单独线程中的训练循环)不应修改图形。 tf.variable_scope()
上下文管理器和 tf.get_variable()
可以 修改图形(尽管它们并不总是),因此您不应在训练循环中使用它们。最安全的做法是在创建训练线程时将 global_step
对象(您首先创建的)作为 args
元组之一传递。然后你可以简单地将你的训练函数重写为:
def training(sess, train_op, loss, global_step):
while (True):
_, loss_value = sess.run([train_op, loss])
current_step = sess.run(global_step)
if current_step >= FLAGS.max_steps: break
为了回答你的其他问题,运行 global_step.eval(session=sess)
或 sess.run(global_step)
只获取 global_step
变量的当前值,不会重新执行其余的你的图表。这是获取 tf.Variable
值以用于 Python 代码的推荐方法。