How to avoid recalculating a function when we need to backpropagate through it twice?

In PyTorch, I want to do the following computation:

l1 = f(x.detach(), y)
l1.backward(retain_graph=True)
l2 = -1*f(x, y.detach())
l2.backward()

where f is some function, and x and y are tensors that require gradients. Note that x and y may both be the results of earlier computations that use shared parameters (for example, perhaps x=g(z) and y=g(w), where g is an nn.Module).

问题是 l1l2 在数值上完全相同,直到负号为止,重复计算 f(x,y) 两次似乎很浪费。如果能够计算一次,然后对结果应用 backward 两次,那就更好了。有什么办法吗?

One possibility is to call autograd.grad manually and update the w.grad field of each nn.Parameter w. But I would like to know whether there is a more direct and cleaner way to do this, using the backward function.
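
For reference, here is a minimal sketch (not part of the original question) of what that manual alternative could look like, reusing the lin/a/b/loss_func setup from the demo code further down: the gradients of f(x, y) with respect to x and y are taken once with autograd.grad, pushed through the shared module with the sign flipped on the x path, and accumulated into each parameter's .grad field by hand.

import torch

lin = torch.nn.Linear(1, 1, bias=False)
lin.weight.data[:] = 1.0
a = torch.tensor([1.0])
b = torch.tensor([2.0])
loss_func = lambda x, y: (x - y).abs()

x = lin(a)
y = lin(b)
l = loss_func(x, y)  # f(x, y) is evaluated only once

# gradients of l w.r.t. x and y; keep the graph so we can still backprop through lin
gx, gy = torch.autograd.grad(l, (x, y), retain_graph=True)

# push -gx through the x path and +gy through the y path of the shared module
params = tuple(lin.parameters())
grads = torch.autograd.grad((x, y), params, grad_outputs=(-gx, gy))

# accumulate into the .grad fields by hand
for p, g in zip(params, grads):
    p.grad = g if p.grad is None else p.grad + g

print(lin.weight.grad)  # should print the same gradient as options 1 and 2 below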

I got this answer from here.

We can compute f(x, y) once, without detaching either x or y, as long as we make sure that the gradient flowing through x is multiplied by -1. This can be done using register_hook:
x.register_hook(lambda t: -t)
l = f(x,y)
l.backward()
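
One caveat worth noting (an addition, not from the original answer): the hook stays registered on x, so it would also negate the gradient in any later backward pass through x. register_hook returns a handle that can be used to remove the hook once it is no longer needed:

handle = x.register_hook(lambda t: -t)
l = f(x, y)
l.backward()
handle.remove()  # stop negating gradients flowing through x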

Here is code demonstrating that this works:

import torch

lin = torch.nn.Linear(1, 1, bias=False)
lin.weight.data[:] = 1.0
a = torch.tensor([1.0])
b = torch.tensor([2.0])
loss_func = lambda x, y: (x - y).abs()

# option 1: this is the inefficient option, presented in the original question
lin.zero_grad()
x = lin(a)
y = lin(b)
loss1 = loss_func(x.detach(), y)
loss1.backward(retain_graph=True)
loss2 = -1 * loss_func(x, y.detach())  # second invocation of `loss_func` - not efficient!
loss2.backward()
print(lin.weight.grad)

# option 2: this is the efficient method, suggested in this answer. 
lin.zero_grad()
x = lin(a)
y = lin(b)
x.register_hook(lambda t: -t)  # negate the gradient flowing back through x
loss = loss_func(x, y)  # only one invocation of `loss_func` - more efficient!
loss.backward()
print(lin.weight.grad)  # the output of this is identical to the previous print, which confirms the method

# option 3 - this should not be equivalent to the previous options, used just for comparison
lin.zero_grad()
x = lin(a)
y = lin(b)
loss = loss_func(x, y)
loss.backward()
print(lin.weight.grad)