tf.Print() 导致梯度错误
tf.Print() causing an error with gradients
我试图使用 tf.Print 调试语句来更好地理解 compute_gradients() 中报告的梯度和变量的格式,但 运行 出现了意想不到的问题。训练例程和调试例程(gvdebug)如下:
def gvdebug(g, v):
#g = tf.Print(g,[g],'G: ')
#v = tf.Print(v,[v],'V: ')
g2 = tf.zeros_like(g, dtype=tf.float32)
v2 = tf.zeros_like(v, dtype=tf.float32)
g2 = g
v2 = v
return g2,v2
# Define training operation
def training(loss, global_step, learning_rate=0.1):
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
grads_and_vars = optimizer.compute_gradients(loss)
gv2 = [gvdebug(gv[0], gv[1]) for gv in grads_and_vars]
train_op = optimizer.apply_gradients(gv2, global_step=global_step)
return train_op
此代码工作正常(但不打印),但如果我取消注释 gvdebug() 中的两行 tf.Print 我会收到来自 apply_gradients 的错误消息:'TypeError: Variable must be a tf.Variable' .我以为 tf.Print 只是通过了张量——我做错了什么?
TL;DR
不要尝试 tf.Print
gv[1]
因为它是 tf.Variable
。它就像一个指向在 gv[0]
.
中创建 gradient
的变量的指针
更多信息
当你 运行 compute_gradients
它 returns 列表 gradients
and their corresponding tf.Variable
s.
grads_and_vars
的每个元素都是一个Tensor
和一个tf.Variable
。需要注意的是不是变量的值
删除 v = tf.Print(v,[v],'V: ')
后你的代码对我有用
我试图使用 tf.Print 调试语句来更好地理解 compute_gradients() 中报告的梯度和变量的格式,但 运行 出现了意想不到的问题。训练例程和调试例程(gvdebug)如下:
def gvdebug(g, v):
#g = tf.Print(g,[g],'G: ')
#v = tf.Print(v,[v],'V: ')
g2 = tf.zeros_like(g, dtype=tf.float32)
v2 = tf.zeros_like(v, dtype=tf.float32)
g2 = g
v2 = v
return g2,v2
# Define training operation
def training(loss, global_step, learning_rate=0.1):
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
grads_and_vars = optimizer.compute_gradients(loss)
gv2 = [gvdebug(gv[0], gv[1]) for gv in grads_and_vars]
train_op = optimizer.apply_gradients(gv2, global_step=global_step)
return train_op
此代码工作正常(但不打印),但如果我取消注释 gvdebug() 中的两行 tf.Print 我会收到来自 apply_gradients 的错误消息:'TypeError: Variable must be a tf.Variable' .我以为 tf.Print 只是通过了张量——我做错了什么?
TL;DR
不要尝试 tf.Print
gv[1]
因为它是 tf.Variable
。它就像一个指向在 gv[0]
.
gradient
的变量的指针
更多信息
当你 运行 compute_gradients
它 returns 列表 gradients
and their corresponding tf.Variable
s.
grads_and_vars
的每个元素都是一个Tensor
和一个tf.Variable
。需要注意的是不是变量的值
删除 v = tf.Print(v,[v],'V: ')