为什么这个函数参数在每次调用中都是相同的,尽管传递了不同的值? (在循环中创建闭包)

Why is this function parameter identical in every invocation, despite passing different value? (Creating closures in a loop)

我正在使用 PyTorch 并尝试在模型参数上注册挂钩。下面的代码创建了 lambda 函数来添加到每个模型参数,所以我可以在 hook 中看到梯度属于哪个张量

import torch
import torchvision

# define model and random train batch
model = torchvision.models.alexnet()
input = torch.rand(10, 3, 224, 224)   # batch of 10 images
targets = torch.zeros(10).long()

def grad_hook_template(param, name, grad):
    print(f'Receive grad for {name} w whape {grad.shape}')

# add one lambda hook to each parameter
for name, param in model.named_parameters():
    print(f'Register hook for {name}')

    # use a lambda so we can pass additional information to the hook, which should only take one parameter
    param.register_hook(lambda grad: grad_hook_template(param, name, grad))

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()

prediction = model(input)
loss = loss_fn(prediction, targets)
loss.backward()
optimizer.step()

结果是 grad_hook_templatenameparam 参数总是相同的值(和 id),但是 grad 参数总是不同的(如预期的那样)。为什么当我注册钩子时,lambda 似乎每次都引用相同的局部变量?

我读过例如here 循环不会创建新的范围并且闭包在 Python 中是词法的,即我传递给 lambda 的 nameparam 只是指针和任何值他们在循环结束时被每个人看到这个指针。但是我能做些什么呢? copy.copy()变量?

这是 FAQ 的回答。

解决方案包括

  • 使用 functools.partial 而不是 lambda
  • 使用 lambda 的默认参数来捕获变量的值

您 运行 进入了 后期绑定闭包 。变量 paramname 在调用时查找,而不是在定义它们所用的函数时查找。在调用这些函数中的任何一个时,nameparam 处于循环中的最后一个值。要解决这个问题,您可以这样做:

for name, param in model.named_parameters():
    print(f'Register hook for {name}')
    param.register_hook(lambda grad, name=name, param=param: grad_hook_template(param, name, grad))

不过,我认为使用 functools.partial 是正确的解决方案:

from functools import partial

for name, param in model.named_parameters():
    print(f'Register hook for {name}')
    param.register_hook(partial(grad_hook_template, name=name, param=param))

您可以找到有关 late binding closures at the Common Gotchas page of the Hitchhiker's Guide to Python as well as in the Python docs 的更多信息。

请注意,这同样适用于使用 def 关键字定义的函数。