Trouble with minimal hvp on pytorch model

While autograd's hvp tool seems to work very well on plain functions, as soon as a model is involved the Hessian-vector product seems to come out as 0. Some code follows.
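For reference, this is the kind of plain-function case that does work for me (a minimal sketch; the standalone quartic f here is just for illustration, not part of the model below):

import torch
from torch.autograd.functional import hvp

# a plain scalar function of a tensor input
def f(x):
    return 0.25 * torch.sum(x) ** 4

x = torch.ones(3)
v = torch.ones(3)

# hvp returns (f(x), H(x) @ v); both come out nonzero, as expected
out, hv = hvp(f, x, v)
print(out, hv)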

First, I define the world's simplest model:

import torch
from torch import nn

class SimpleMLP(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, out_dim),
        )

    def forward(self, x):
        '''Forward pass'''
        return self.layers(x)

Then, a loss function:

def objective(x):
  return torch.sum(0.25 * torch.sum(x)**4)

We instantiate it:

Arows = 2
Acols = 2

mlp = SimpleMLP(Arows, Acols)

Finally, I define a "forward" function (distinct from the model's forward method) that represents the full model + loss we want to analyze:

def forward(*params_list):
  for param_val, model_param in zip(params_list, mlp.parameters()):
    model_param.data = param_val
 
  x = torch.ones((Arows,))
  return objective(mlp(x))

This passes a vector of ones through the single-layer "mlp" and feeds the output to our quartic loss.

Now, I try to compute:

v = torch.ones((6,))
v_tensors = []
idx = 0
#this code "reshapes" the v vector as needed
for i, param in enumerate(mlp.parameters()):
  numel = param.numel()
  v_tensors.append(torch.reshape(torch.tensor(v[idx:idx+numel]), param.shape))
  idx += numel
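To spell out where the 6 comes from (a quick check against the same mlp defined above):

# the single Linear layer holds a (2, 2) weight plus a (2,) bias,
# i.e. 4 + 2 = 6 scalars, which is why v has length 6
for name, param in mlp.named_parameters():
    print(name, tuple(param.shape), param.numel())
# layers.0.weight (2, 2) 4
# layers.0.bias (2,) 2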

Finally:

param_tensors = tuple(mlp.parameters())
reshaped_v = tuple(v_tensors)
soln =  torch.autograd.functional.hvp(forward, param_tensors, v=reshaped_v)

But, alas, the Hessian-vector product in soln is all 0s. What is going on?

Have you tried using doubles instead of floats? I ran some tests myself, and backprop with 32-bit floats showed a fairly large error (around 1e-5) compared to double precision.
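If you want to try that quickly, a minimal sketch (assuming the same mlp and forward as above) is to cast everything to float64:

# cast the model's parameters to float64
mlp.double()

# the probe vector must match the parameter dtype
v = torch.ones((6,), dtype=torch.float64)

# the input built inside forward() has to match too, e.g.
# x = torch.ones((Arows,), dtype=torch.float64)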

What is happening is that strict defaults to False in the hvp() function, so a tensor of zeros is returned as the Hessian-vector product instead of an error (source).

If you try with strict=True, you get the error RuntimeError: The output of the user-provided function is independent of input 0. This is not allowed in strict mode. Looking at the full traceback, I suspect this error comes from _check_requires_grad(jac, "jacobian", strict=strict), which indicates that the Jacobian jac is None. That makes sense: assigning to model_param.data inside your forward copies values into the model without creating any autograd edge, so the graph for the loss never connects back to the tensors hvp is differentiating with respect to.
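Here is a stripped-down illustration of that default behaviour, using a toy function that ignores its input rather than your model:

import torch
from torch.autograd.functional import hvp

# a function whose output does not depend on its input at all
def independent(w):
    return torch.ones(()).sum()

w = torch.randn(3)
v = torch.ones(3)

# with the default strict=False this quietly returns zeros
out, hv = hvp(independent, w, v)
print(hv)  # tensor([0., 0., 0.])

# with strict=True it raises the RuntimeError quoted above instead
# hvp(independent, w, v, strict=True)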

Update:

Here is a complete working example:

import torch
from torch import nn

# your loss function
def objective(x):
    return torch.sum(0.25 * torch.sum(x)**4)

# Following are utilities to make nn.Module functional
# borrowed from the link I posted in comment
def del_attr(obj, names):
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])

def set_attr(obj, names, val):
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)

def make_functional(mod):
    orig_params = tuple(mod.parameters())
    # Remove all the parameters in the model
    names = []
    for name, p in list(mod.named_parameters()):
        del_attr(mod, name.split("."))
        names.append(name)
    return orig_params, names

def load_weights(mod, names, params):
    for name, p in zip(names, params):
        set_attr(mod, name.split("."), p)

# your forward function with update
def forward(*new_params):
    # this line replace your for loop
    load_weights(mlp, names, new_params)

    x = torch.ones((Arows,))
    out = mlp(x)
    loss = objective(out)
    return loss

# your simple MLP model
class SimpleMLP(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, out_dim),
        )

    def forward(self, x):
        '''Forward pass'''
        return self.layers(x)


if __name__ == '__main__':
    # your model instantiation
    Arows = 2
    Acols = 2
    mlp = SimpleMLP(Arows, Acols)

    # your vector computation
    v = torch.ones((6,))
    v_tensors = []
    idx = 0
    #this code "reshapes" the v vector as needed
    for i, param in enumerate(mlp.parameters()):
        numel = param.numel()
        v_tensors.append(torch.reshape(torch.tensor(v[idx:idx+numel]), param.shape))
        idx += numel
    reshaped_v = tuple(v_tensors)

    #make model's parameters functional
    params, names = make_functional(mlp)
    params = tuple(p.detach().requires_grad_() for p in params)

    #compute hvp
    soln = torch.autograd.functional.vhp(forward, params, reshaped_v, strict=True)
    print(soln)
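The key difference from your original forward is that load_weights assigns the incoming tensors as plain attributes (after make_functional has removed the registered nn.Parameters), so the model's output is now actually computed from the tensors that the functional API differentiates with respect to; assigning to model_param.data never created that connection. And since the Hessian of this scalar loss is symmetric, the vhp computed here should match the hvp you were after.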