PyTorch 是否有等同于 tf.custom_gradient() 的东西?

Is there a PyTorch equivalent of tf.custom_gradient()?

我是 PyTorch 的新手,但对 TensorFlow 有很多经验。

我想修改一小部分图的梯度:只是单层激活函数的导数。这可以在 Tensorflow 中使用 tf.custom_gradient 轻松完成,它允许您为任何函数提供自定义梯度。

我想在 PyTorch 中做同样的事情,我知道你可以修改 backward() 方法,但这需要你重写 forward() 方法中定义的整个网络的导数,当我只想修改一小部分图形的梯度。 PyTorch 中有类似 tf.custom_gradient() 的东西吗?谢谢!

您可以通过两种方式做到这一点:

1.修改backward()函数:
正如您在问题中已经说过的, 还允许您提供自定义 backward 实现。然而,与你写的相反,你 not 需要 re-write 整个模型的 backward() - 只有 backward() 您要更改的特定层。
这是一个简单而不错的教程,展示了如何完成此操作。

例如,这里有一个自定义 clip 激活函数,它不会杀死 [0, 1] 域外的梯度,而是简单地传递梯度 as-is:

class MyClip(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.clip(x, 0., 1.)

    @staticmethod
    def backward(ctx, grad):
        return grad

现在您可以在模型中的任何地方使用 MyClip 层,而无需担心整体 backward 功能。


2。使用 backward 挂钩 允许您将挂钩附加到网络的不同层 (=sub nn.Modules)。您可以 register_full_backward_hook 到您的图层。该钩子函数可以修改渐变:

The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations.