PyTorch 是否有等同于 tf.custom_gradient() 的东西?
Is there a PyTorch equivalent of tf.custom_gradient()?
我是 PyTorch 的新手,但对 TensorFlow 有很多经验。
我想修改一小部分图的梯度:只是单层激活函数的导数。这可以在 Tensorflow 中使用 tf.custom_gradient 轻松完成,它允许您为任何函数提供自定义梯度。
我想在 PyTorch 中做同样的事情,我知道你可以修改 backward() 方法,但这需要你重写 forward() 方法中定义的整个网络的导数,当我只想修改一小部分图形的梯度。 PyTorch 中有类似 tf.custom_gradient() 的东西吗?谢谢!
您可以通过两种方式做到这一点:
1.修改backward()
函数:
正如您在问题中已经说过的,pytorch 还允许您提供自定义 backward
实现。然而,与你写的相反,你 not 需要 re-write 整个模型的 backward()
- 只有 backward()
您要更改的特定层。
这是一个简单而不错的教程,展示了如何完成此操作。
例如,这里有一个自定义 clip
激活函数,它不会杀死 [0, 1]
域外的梯度,而是简单地传递梯度 as-is:
class MyClip(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return torch.clip(x, 0., 1.)
@staticmethod
def backward(ctx, grad):
return grad
现在您可以在模型中的任何地方使用 MyClip
层,而无需担心整体 backward
功能。
2。使用 backward
挂钩
pytorch 允许您将挂钩附加到网络的不同层 (=sub nn.Module
s)。您可以 register_full_backward_hook
到您的图层。该钩子函数可以修改渐变:
The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input
in subsequent computations.
我是 PyTorch 的新手,但对 TensorFlow 有很多经验。
我想修改一小部分图的梯度:只是单层激活函数的导数。这可以在 Tensorflow 中使用 tf.custom_gradient 轻松完成,它允许您为任何函数提供自定义梯度。
我想在 PyTorch 中做同样的事情,我知道你可以修改 backward() 方法,但这需要你重写 forward() 方法中定义的整个网络的导数,当我只想修改一小部分图形的梯度。 PyTorch 中有类似 tf.custom_gradient() 的东西吗?谢谢!
您可以通过两种方式做到这一点:
1.修改backward()
函数:
正如您在问题中已经说过的,pytorch 还允许您提供自定义 backward
实现。然而,与你写的相反,你 not 需要 re-write 整个模型的 backward()
- 只有 backward()
您要更改的特定层。
这是一个简单而不错的教程,展示了如何完成此操作。
例如,这里有一个自定义 clip
激活函数,它不会杀死 [0, 1]
域外的梯度,而是简单地传递梯度 as-is:
class MyClip(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return torch.clip(x, 0., 1.)
@staticmethod
def backward(ctx, grad):
return grad
现在您可以在模型中的任何地方使用 MyClip
层,而无需担心整体 backward
功能。
2。使用 backward
挂钩
pytorch 允许您将挂钩附加到网络的不同层 (=sub nn.Module
s)。您可以 register_full_backward_hook
到您的图层。该钩子函数可以修改渐变:
The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of
grad_input
in subsequent computations.