PyTorch 在损失函数中使用 autograd 时不更新权重

PyTorch not updating weights when using autograd in loss function

我正在尝试使用网络相对于其输入的梯度作为损失函数的一部分。但是,每当我尝试计算它时,训练都会进行但权重不会更新

import torch
import torch.optim as optim
import torch.autograd as autograd


ic = torch.rand((25, 3))
ic = torch.tensor(ic, requires_grad=True)
optimizer = optim.RMSprop([ic], lr=1e-2)

for itr in range(1, 50):
    optimizer.zero_grad()
    sol = torch.tanh(.5*torch.stack(100*[ic])) # simplified for minimal working example
    
    dx = sol[-1, :, 0]
    dxdxy, = autograd.grad(dx, 
                           inputs=ic,
                           grad_outputs = torch.ones(ic.shape[0]), # batchwise
                           retain_graph=True
                          )
    dxdxy = torch.tensor(dxdxy, requires_grad=True)
    loss = torch.sum(dxdxy)
    
    loss.backward()
    optimizer.step()
    
    if itr % 5 == 0:
        print(loss)

我做错了什么?

当你 运行 autograd.grad 没有设置标志 create_graphTrue 那么你将不会获得连接到计算图的输出,这意味着您将无法进一步优化 w.r.t ic(并获得您希望在此处执行的高阶导数)。 来自 torch.autograd.grad 的文档字符串:

create_graph (bool, optional): If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Default: False.

使用 dxdxy = torch.tensor(dxdxy, requires_grad=True) 正如你在这里尝试的那样不会有帮助,因为连接到 ic 的计算图到那时已经丢失(因为 create_graphFalse),你所做的就是创建一个新的计算图,其中 dxdxy 是叶节点。

请参阅下面的解决方案(请注意,当您创建 ic 时,您可以设置 requires_grad=True 因此第二行是多余的(这不是一个逻辑问题,只是更长的代码):

import torch
import torch.optim as optim
import torch.autograd as autograd

ic = torch.rand((25, 3),requires_grad=True) #<-- requires_grad to True here
#ic = torch.tensor(ic, requires_grad=True) #<-- redundant
optimizer = optim.RMSprop([ic], lr=1e-2)

for itr in range(1, 50):
    optimizer.zero_grad()
    sol = torch.tanh(.5 * torch.stack(100 * [ic]))  # simplified for minimal working example

    dx = sol[-1, :, 0]
    dxdxy, = autograd.grad(dx,
                           inputs=ic,
                           grad_outputs=torch.ones(ic.shape[0]),  # batchwise
                           retain_graph=True, create_graph=True # <-- important
                           )
    #dxdxy = torch.tensor(dxdxy, requires_grad=True) #<-- won't do the trick. Remove
    loss = torch.sum(dxdxy)

    loss.backward()
    optimizer.step()

    if itr % 5 == 0:
        print(loss)