Tensorflow 自定义梯度的解析解释是什么?

What is the analytic interpretation for Tensorflow custom gradient?

在官方 tf.custom_gradient 文档中,它展示了如何为 log(1 + exp(x))

定义自定义渐变
@tf.custom_gradient
def log1pexp(x):
  e = tf.exp(x)
  def grad(dy):
    return dy * (1 - 1 / (1 + e))
  return tf.math.log(1 + e), grad

y = log(1 + exp(x))时,解析导数为dy/dx = (1 - 1 / (1 + exp(x)))

但是在代码中 def grad 表示其 dy * (1 - 1 / (1 + exp(x)))dy/dx = dy * (1 - 1 / (1 + exp(x))) 不是有效的等式。虽然 dx = dy * (1 - 1 / (1 + exp(x))) 是错误的,因为它应该是倒数。

grad 函数等同于什么?

您正在查看的额外 dy 是激活本身的价值。因为如果您查看优化器方程式,则需要将梯度与输出值相乘。因此,这就是这样做的原因。

我终于想通了。 dy 应称为 upstream_gradientupstream_dy_dx

根据链式法则我​​们知道

其中dx[i]/dx[i+1]是当前函数的梯度。

所以dy是这个函数之前所有上游梯度的乘积。

因此,如果您忘记乘以 dy,它实际上与 tf.stop_gradient

相同

这是一个演示这个的代码。完整笔记本 here

@tf.custom_gradient
def foo(x):
    tf.debugging.assert_rank(x, 0)

    def grad(dy_dx_upstream):
        dy_dx = 2 * x
        dy_dx_downstream = dy_dx * dy_dx_upstream
        tf.print(f'x={x}\tupstream={dy_dx_upstream}\tcurrent={dy_dx}\t\tdownstream={dy_dx_downstream}')
        return dy_dx_downstream
    
    y = x ** 2
    tf.print(f'x={x}\ty={y}')
    
    return y, grad


x = tf.constant(2.0, dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y = foo(foo(foo(x))) # y = x ** 8

tf.print(f'\nfinal dy/dx={tape.gradient(y, x)}')

输出

x=2.0   y=4.0
x=4.0   y=16.0
x=16.0  y=256.0
x=16.0  upstream=1.0    current=32.0        downstream=32.0
x=4.0   upstream=32.0   current=8.0     downstream=256.0
x=2.0   upstream=256.0  current=4.0     downstream=1024.0

final dy/dx=1024.0