在调用 backward() 之前应用非 torch 函数处理损失?

Applying non-torch function on loss before calling backward()?

我想在计算梯度(调用 backward())之前对最终计算出的损失应用自定义非 torch 函数。一个例子是用自定义的 pythonic 非 torch 均值函数替换损失向量上的 torch.mean() 。但是这样做会破坏计算图。我无法使用 torch 运算符重写自定义均值函数,我不知道该怎么做。有什么建议吗?

在 pytorch 中,您可以通过继承 torch.autograd.Function 轻松地做到这一点:您需要做的就是实现您的自定义 forward() 和相应的 backward() 方法。因为我不知道您打算编写的函数,所以我将通过以适用于自动微分的方式实现正弦函数来对其进行演示。请注意,您需要有一种方法来计算函数相对于其输入的导数,以实现向后传递。

import torch

class MySin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        """ compute forward pass of custom function """
        ctx.save_for_backward(inp)  # save activation for backward pass
        return inp.sin()  # compute forward pass, can also be computed by any other library

    @staticmethod
    def backward(ctx, grad_out):
        """ compute product of output gradient with the 
        jacobian of your function evaluated at input """
        inp, = ctx.saved_tensors
        grad_inp = grad_out * torch.cos(inp)  # propagate gradient, can also be computed by any other library
        return grad_inp

要使用它,您可以在输入中使用函数 sin = MySin.apply

documentation.

中还有另一个例子