为什么这个函数参数在每次调用中都是相同的,尽管传递了不同的值? (在循环中创建闭包)
Why is this function parameter identical in every invocation, despite passing different value? (Creating closures in a loop)
我正在使用 PyTorch 并尝试在模型参数上注册挂钩。下面的代码创建了 lambda 函数来添加到每个模型参数,所以我可以在 hook 中看到梯度属于哪个张量
import torch
import torchvision
# define model and random train batch
model = torchvision.models.alexnet()
input = torch.rand(10, 3, 224, 224) # batch of 10 images
targets = torch.zeros(10).long()
def grad_hook_template(param, name, grad):
print(f'Receive grad for {name} w whape {grad.shape}')
# add one lambda hook to each parameter
for name, param in model.named_parameters():
print(f'Register hook for {name}')
# use a lambda so we can pass additional information to the hook, which should only take one parameter
param.register_hook(lambda grad: grad_hook_template(param, name, grad))
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()
prediction = model(input)
loss = loss_fn(prediction, targets)
loss.backward()
optimizer.step()
结果是 grad_hook_template
的 name
和 param
参数总是相同的值(和 id
),但是 grad
参数总是不同的(如预期的那样)。为什么当我注册钩子时,lambda 似乎每次都引用相同的局部变量?
我读过例如here 循环不会创建新的范围并且闭包在 Python 中是词法的,即我传递给 lambda 的 name
和 param
只是指针和任何值他们在循环结束时被每个人看到这个指针。但是我能做些什么呢? copy.copy()
变量?
这是 FAQ 的回答。
解决方案包括
- 使用
functools.partial
而不是 lambda
- 使用 lambda 的默认参数来捕获变量的值
您 运行 进入了 后期绑定闭包 。变量 param
和 name
在调用时查找,而不是在定义它们所用的函数时查找。在调用这些函数中的任何一个时,name
和 param
处于循环中的最后一个值。要解决这个问题,您可以这样做:
for name, param in model.named_parameters():
print(f'Register hook for {name}')
param.register_hook(lambda grad, name=name, param=param: grad_hook_template(param, name, grad))
不过,我认为使用 functools.partial
是正确的解决方案:
from functools import partial
for name, param in model.named_parameters():
print(f'Register hook for {name}')
param.register_hook(partial(grad_hook_template, name=name, param=param))
您可以找到有关 late binding closures at the Common Gotchas page of the Hitchhiker's Guide to Python as well as in the Python docs 的更多信息。
请注意,这同样适用于使用 def
关键字定义的函数。
我正在使用 PyTorch 并尝试在模型参数上注册挂钩。下面的代码创建了 lambda 函数来添加到每个模型参数,所以我可以在 hook 中看到梯度属于哪个张量
import torch
import torchvision
# define model and random train batch
model = torchvision.models.alexnet()
input = torch.rand(10, 3, 224, 224) # batch of 10 images
targets = torch.zeros(10).long()
def grad_hook_template(param, name, grad):
print(f'Receive grad for {name} w whape {grad.shape}')
# add one lambda hook to each parameter
for name, param in model.named_parameters():
print(f'Register hook for {name}')
# use a lambda so we can pass additional information to the hook, which should only take one parameter
param.register_hook(lambda grad: grad_hook_template(param, name, grad))
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()
prediction = model(input)
loss = loss_fn(prediction, targets)
loss.backward()
optimizer.step()
结果是 grad_hook_template
的 name
和 param
参数总是相同的值(和 id
),但是 grad
参数总是不同的(如预期的那样)。为什么当我注册钩子时,lambda 似乎每次都引用相同的局部变量?
我读过例如here 循环不会创建新的范围并且闭包在 Python 中是词法的,即我传递给 lambda 的 name
和 param
只是指针和任何值他们在循环结束时被每个人看到这个指针。但是我能做些什么呢? copy.copy()
变量?
这是 FAQ 的回答。
解决方案包括
- 使用
functools.partial
而不是lambda
- 使用 lambda 的默认参数来捕获变量的值
您 运行 进入了 后期绑定闭包 。变量 param
和 name
在调用时查找,而不是在定义它们所用的函数时查找。在调用这些函数中的任何一个时,name
和 param
处于循环中的最后一个值。要解决这个问题,您可以这样做:
for name, param in model.named_parameters():
print(f'Register hook for {name}')
param.register_hook(lambda grad, name=name, param=param: grad_hook_template(param, name, grad))
不过,我认为使用 functools.partial
是正确的解决方案:
from functools import partial
for name, param in model.named_parameters():
print(f'Register hook for {name}')
param.register_hook(partial(grad_hook_template, name=name, param=param))
您可以找到有关 late binding closures at the Common Gotchas page of the Hitchhiker's Guide to Python as well as in the Python docs 的更多信息。
请注意,这同样适用于使用 def
关键字定义的函数。