如何在嵌套渐变带中重用内部渐变?

How to reuse the inner gradient in nested gradient tapes?

我正在开发 tensorflow 1.15 中的一个例程,该例程为不同的向量评估多个 hessian 向量积

def hessian_v_prod(self, v):
    with tf.GradientTape() as t1:
        with tf.GradientTape() as t2:
            # evaluate loss which uses self.variables
            loss_val = self.loss()
        grad = t2.gradient(loss_val, self.variables)
        v_hat = tf.reduce(tf.multiply(v, grad))

    return t1.gradient(v_hat, self.variables)

每次调用此函数时,它都必须评估内循环并计算梯度,但无论 v 的值如何,这都是相同的。每次调用此函数时如何重用 grad 值?

我看到有一个选项可以将磁带创建为 tf.GradientTape(persist=True),它可以保留磁带的资源,但不知道如何将其合并到我的功能中。

我不得不深入研究 GradientTape 的内部工作原理,但设法搞清楚了。 在这里分享给可能有同样问题的其他人。 剧透警告:有点黑!

首先,调用时实际发生了什么

with tf.GradientTape() as tape:
    loss_value = self.loss()
tape.gradient(loss_value, vars)

为了找出这一点,我们需要检查 __enter__()__exit__() 函数,它们分别在 with 块的开始和结束处调用。

tensorflow_core/python/eager/backprop.py

def __enter__(self):
    """Enters a context inside which operations are recorded on this tape."""
    self._push_tape()
    return self

def __exit__(self, typ, value, traceback):
    """Exits the recording context, no further operations are traced."""
    if self._recording:
        self._pop_tape()

我们可以自己使用这些私有函数来控制录音,而不需要 with 块。

# Initialize outer and inner tapes
self.gt_outer = tf.GradientTape(persistent=True)
self.gt_inner = tf.GradientTape(persistent=True)

# Begin Recording
self.gt_outer._push_tape()
self.gt_inner._push_tape()

# evaluate loss which uses self.variables
loss_val = self.loss()

# stop recording on inner tape
self.gt_inner._pop_tape()

# Evaluate the gradient on the inner tape
self.gt_grad = self.gt_inner.gradient(loss_val, self.variables)

# Stop recording on the outer tape
self.gt_outer._pop_tape()

现在,每当我们需要评估 hessian 向量积时,我们都可以重复使用外部梯度带。

def hessian_v_prod(self, v):
    self.gt_outer._push_tape()
    v_hat = tf.reduce(tf.multiply(v, self.gt_grad))
    self.gt_outer._pop_tape()
    return self.gt_outer.gradient(v_hat, self.variables)

请注意,我们正在持久化磁带,因此每次计算 hessian 向量积时都会使用更多内存。无法保留部分磁带内存,因此在某些时候有必要重置磁带。

# reset tapes
self.gt_outer._tape = None
self.gt_inner._tape = None

在此之后要再次使用它们,我们需要重新评估内部循环。它并不完美,但它完成了工作并以更大的内存使用为代价提供了显着的加速(接近 x2)。