Trouble with minimal hvp on pytorch model
Although autograd's hvp tool seems to work very well on plain functions, the Hessian-vector product seems to come out as 0 as soon as a model is involved. Some code follows.
First, I define the world's simplest model:
import torch
from torch import nn

class SimpleMLP(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, out_dim),
        )

    def forward(self, x):
        '''Forward pass'''
        return self.layers(x)
Then, a loss function:
def objective(x):
    return torch.sum(0.25 * torch.sum(x)**4)
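Just to spell out why I expect something nonzero here (a quick derivation of my own, not strictly part of the setup): writing \(s = \sum_i x_i\), the objective is

\[
f(x) = \tfrac{1}{4}\, s^4, \qquad \frac{\partial f}{\partial x_i} = s^3, \qquad \frac{\partial^2 f}{\partial x_i \partial x_j} = 3 s^2 .
\]

Since the linear layer is affine in its weights and bias, the loss is the same kind of quartic in the parameters, so the Hessian with respect to the parameters (and therefore a Hessian-vector product with an all-ones v) should be nonzero whenever \(s \neq 0\).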
We instantiate it:
Arows = 2
Acols = 2
mlp = SimpleMLP(Arows, Acols)
Finally, I define a "forward" function (distinct from the model's own forward) that acts as the full model + loss that we want to analyze:
def forward(*params_list):
    for param_val, model_param in zip(params_list, mlp.parameters()):
        model_param.data = param_val
    x = torch.ones((Arows,))
    return objective(mlp(x))
This passes a vector through the single-layer "mlp" and feeds the result into our loss.
Now, I try to compute:
v = torch.ones((6,))  # 6 = 2*2 weight entries + 2 bias entries of the Linear layer
v_tensors = []
idx = 0
# this code "reshapes" the v vector as needed
for i, param in enumerate(mlp.parameters()):
    numel = param.numel()
    v_tensors.append(torch.reshape(torch.tensor(v[idx:idx+numel]), param.shape))
    idx += numel
And finally:
param_tensors = tuple(mlp.parameters())
reshaped_v = tuple(v_tensors)
soln = torch.autograd.functional.hvp(forward, param_tensors, v=reshaped_v)
But, alas, the Hessian-vector product in soln is all zeros. What is going on?
Have you tried using doubles instead of floats? I ran some tests myself, and backpropagation with 32-bit floats showed a fairly large error (around 1e-5) compared with double precision.
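If you want to try that, one way to cast the whole setup to float64 (just a sketch of the idea, reusing the names from the question) would be something like:

# rebuild the model and all hvp inputs in double precision
mlp = SimpleMLP(Arows, Acols).double()

def forward(*params_list):
    for param_val, model_param in zip(params_list, mlp.parameters()):
        model_param.data = param_val
    x = torch.ones((Arows,), dtype=torch.float64)
    return objective(mlp(x))

param_tensors = tuple(mlp.parameters())
reshaped_v = tuple(t.double() for t in v_tensors)
soln = torch.autograd.functional.hvp(forward, param_tensors, v=reshaped_v)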
What is happening is that strict defaults to False in the hvp() function, so a tensor of zeros is returned as the Hessian-vector product instead of an error (source). If you try it with strict=True, you get the error RuntimeError: The output of the user-provided function is independent of input 0. This is not allowed in strict mode. Looking at the full traceback, I suspect this error comes from _check_requires_grad(jac, "jacobian", strict=strict), which indicates that the Jacobian jac is None. In other words, because your forward writes the incoming values into model_param.data, no autograd graph connects params_list to the loss, so the output really is independent of the inputs that hvp differentiates with respect to.
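For reference, you can reproduce this on the setup from the question simply by passing strict=True to the same call:

# same call as in the question, but strict=True turns the silent zeros into an error
soln = torch.autograd.functional.hvp(forward, param_tensors, v=reshaped_v, strict=True)
# RuntimeError: The output of the user-provided function is independent of input 0.
# This is not allowed in strict mode.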
Update:
Below is a complete working example:
import torch
from torch import nn

# your loss function
def objective(x):
    return torch.sum(0.25 * torch.sum(x)**4)

# Following are utilities to make nn.Module functional
# borrowed from the link I posted in the comments
def del_attr(obj, names):
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])

def set_attr(obj, names, val):
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)

def make_functional(mod):
    orig_params = tuple(mod.parameters())
    # Remove all the parameters in the model
    names = []
    for name, p in list(mod.named_parameters()):
        del_attr(mod, name.split("."))
        names.append(name)
    return orig_params, names

def load_weights(mod, names, params):
    for name, p in zip(names, params):
        set_attr(mod, name.split("."), p)

# your forward function with update
def forward(*new_params):
    # this line replaces your for loop
    load_weights(mlp, names, new_params)
    x = torch.ones((Arows,))
    out = mlp(x)
    loss = objective(out)
    return loss

# your simple MLP model
class SimpleMLP(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, out_dim),
        )

    def forward(self, x):
        '''Forward pass'''
        return self.layers(x)

if __name__ == '__main__':
    # your model instantiation
    Arows = 2
    Acols = 2
    mlp = SimpleMLP(Arows, Acols)

    # your vector computation
    v = torch.ones((6,))
    v_tensors = []
    idx = 0
    # this code "reshapes" the v vector as needed
    for i, param in enumerate(mlp.parameters()):
        numel = param.numel()
        v_tensors.append(torch.reshape(torch.tensor(v[idx:idx+numel]), param.shape))
        idx += numel
    reshaped_v = tuple(v_tensors)

    # make model's parameters functional
    params, names = make_functional(mlp)
    params = tuple(p.detach().requires_grad_() for p in params)

    # compute hvp
    soln = torch.autograd.functional.vhp(forward, params, reshaped_v, strict=True)
    print(soln)
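As a side note (not part of the fix above): if you are on a recent PyTorch (2.0 or later, where torch.func is available), the same computation can be written without the manual make_functional machinery, using functional_call together with torch.func.hvp. A rough sketch under that assumption:

import torch
from torch import nn
from torch.func import functional_call, hvp

# same toy model and loss as above
mlp = nn.Sequential(nn.Linear(2, 2))

def objective(x):
    return torch.sum(0.25 * torch.sum(x)**4)

def loss_wrt_params(params):
    # run the module with the given parameter dict instead of its own parameters
    out = functional_call(mlp, params, (torch.ones(2),))
    return objective(out)

params = {name: p.detach() for name, p in mlp.named_parameters()}
v = {name: torch.ones_like(p) for name, p in params.items()}

# returns (loss value, Hessian-vector product), the latter as a dict matching params
loss_val, hvp_out = hvp(loss_wrt_params, (params,), (v,))
print(hvp_out)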