How to register a dynamic backward hook on tensors in PyTorch?

I'm trying to register a backward hook on each neuron's weights in a network. By dynamic I mean that it will take a value and multiply the associated gradients by that value.

From here it seems it's possible to register a hook on a tensor with a fixed value (though note that I need it to take a value that will change). From here it also seems possible to register a hook on all of the parameters -- they use it there to do gradient clipping (though note that I'm trying to only do it on each neuron's weights).
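For reference, the two linked approaches look roughly like this (a minimal sketch; the tensor w and the clipping bounds are placeholders of my choosing):

import torch
import torch.nn as nn

w = torch.rand(5, 3, requires_grad=True)

# Fixed-value hook: the multiplier is baked in at registration time
w.register_hook(lambda g: g * 2.)

# Hook on every parameter, as in the gradient-clipping example
model = nn.Linear(3, 5)
for p in model.parameters():
    p.register_hook(lambda g: torch.clamp(g, -1., 1.))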

If my network were as follows:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.fc1 = nn.Linear(3, 5)
        self.fc2 = nn.Linear(5, 10)
        self.fc3 = nn.Linear(10, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        return x

The first layer has 5 neurons, each with 3 associated weights. Hence, this layer should have 5 hooks that modify (i.e. change the current gradient by multiplying it) their 3 associated weights' gradients during the backward step.
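For concreteness, fc1's weight matrix has one row per neuron, so "one hook per neuron" amounts to row-wise scaling of that weight's gradient (a sketch using the Model above):

import torch

net = Model()

# fc1.weight has shape (out_features, in_features) = (5, 3):
# row i holds neuron i's 3 weights
print(net.fc1.weight.shape)  # torch.Size([5, 3])

# The gradient has the same shape, so modifying neuron i's gradients
# means scaling row i of fc1.weight.grad
out = net(torch.rand(1, 3))
out.sum().backward()
print(net.fc1.weight.grad.shape)  # torch.Size([5, 3])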

Example of training pseudocode:

net = Model()
for epoch in epochs:
    out = net(data)
    loss = criterion(out, target)
    optimizer.zero_grad()
    loss.backward()
    for hook in list_of_hooks: #not sure if there's a more "pytorch" way of doing this without a for loop
        hook(random_value)
    optimizer.step()

How to take advantage of lambdas' closure over names

A short example:

import torch

net_params = torch.rand(5, 3, requires_grad=True)

msg = "Hello!"

# The lambda closes over the *name* msg, not its value at registration time
net_params.register_hook(lambda g: print(msg))


out1 = net_params * 2.

loss = out1.sum()
loss.backward()  # Activates the hook and prints "Hello!"


msg = "How are you?"  # The lambda is affected by this change

out2 = net_params ** 4.
loss2 = out2.sum()

loss2.backward()  # Activates the hook again and prints "How are you?"
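An equivalent, arguably more explicit pattern is a small callable object that holds the multiplier as mutable state; register_hook accepts any callable. A sketch (DynamicHook is an illustrative name, not a PyTorch class):

import torch

class DynamicHook:
    """Multiplies the incoming gradient by whatever self.value currently is."""
    def __init__(self, value):
        self.value = value

    def __call__(self, grad):
        return grad * self.value

net_params = torch.rand(5, 3, requires_grad=True)
hook = DynamicHook(torch.ones(5, 3))
net_params.register_hook(hook)

(net_params * 2.).sum().backward()  # gradient multiplied by ones (i.e. unchanged)

hook.value = torch.rand(5, 3)       # later backward() calls use the new value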

A possible solution to your problem:

net = Model()
# Replace it with your computed values
rand_values = torch.rand(net.fc1.out_features, net.fc1.in_features)

net.fc1.weight.register_hook(lambda g: g * rand_values)  # looks up rand_values at call time

for epoch in epochs:
    out = net(data)
    loss = criterion(out, target)
    optimizer.zero_grad()
    loss.backward()  # fc1 gradients are multiplied by rand_values
    optimizer.step()

    # Update rand_values. The lambda computation will change accordingly
    rand_values = torch.rand(net.fc1.out_features, net.fc1.in_features)
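One practical note: register_hook returns a handle, so if at some point you want the gradients to flow unmodified again you can remove the hook:

handle = net.fc1.weight.register_hook(lambda g: g * rand_values)
# ... train for a while ...
handle.remove()  # fc1 gradients are no longer modified after this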

Edit

To make things clearer, if you specifically want to multiply each set of weights i by a single value vi, you can take advantage of broadcasting semantics and define values = torch.tensor([v0, v1, v2, v3, v4]).reshape(5, 1), then the lambda becomes lambda g: g * values.
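A runnable sketch of that edit (the five values here are arbitrary placeholders):

import torch

net = Model()

# One scalar per neuron, reshaped to (5, 1) so it broadcasts row-wise
# over the (5, 3) gradient: row i is multiplied by values[i]
values = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]).reshape(5, 1)
net.fc1.weight.register_hook(lambda g: g * values)

out = net(torch.rand(1, 3))
out.sum().backward()  # row i of fc1.weight.grad has been scaled by values[i]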