tf.where returns NaN 的 TensorFlow 梯度不应该

TensorFlow gradient with tf.where returns NaN when it shouldn't

以下是可重现的代码。如果你 运行 它,你会看到在第一个 sess 运行 中,结果是 nan,而第二种情况给出了正确的梯度值 0.5。但是根据 tf.where 和指定的条件,它们应该 return 相同的值。我也不明白为什么 tf.where 函数梯度在 1 或 -1 时为 nan,这对我来说似乎是完全好的输入值。

tf.reset_default_graph()
x = tf.get_variable('x', shape=[1])
condition = tf.less(x, 0.0)
output = tf.where(condition, -tf.log(-x + 1), tf.log(x + 1))
deriv = tf.gradients(output, x)
with tf.Session() as sess:
    print(sess.run(deriv, {x:np.array([-1])}))

logg = -tf.log(-x+1)
derivv = tf.gradients(logg, x)
with tf.Session() as sess:
    print(sess.run(derivv, {x:np.array([-1])}))

感谢评论!

如@mikkola 提供的 github issue 中所述,问题源于 tf.where 的内部实现。基本上,计算两个备选方案(及其梯度),并且通过条件的乘法只选择正确的部分。 las,如果对于 未选择 的部分,渐变是 infnan,即使乘以 0,您最终也会得到 nan传播到结果。

由于该问题已于 2016 年 5 月提交(即 tensorflow v0.7!)且此后未修补,因此可以放心地假设这不会很快出现并开始寻找解决方法。

修复它的最简单方法是修改您的语句,使它们始终有效且可微分,即使对于不打算选择的值也是如此。

一种通用技术是将输入值裁剪到其有效域内。因此,例如,在您的情况下,您可以使用

cond = tf.less(x, 0.0)
output = tf.where(cond,
  -tf.log(-tf.where(cond, x, 0) + 1),
  tf.log(tf.where(cond, 0, x) + 1))

在您的特定情况下,使用

会更简单
output = tf.sign(x) * tf.log(tf.abs(x) + 1)