如何创建带条件的 PyTorch 挂钩?

How to create a PyTorch hook with conditions?

我正在学习钩子并使用二值化神经网络。问题是有时我的梯度在向后传递中为 0。我正在尝试用某个值替换这些渐变。

假设我有以下网络

import torch
import torch.nn as nn
import torch.optim as optim

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(1, 2)
        self.fc2 = nn.Linear(2, 3)
        self.fc3 = nn.Linear(3, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)        
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Model()

opt = optim.Adam(net.parameters())

还有一些功能

features = torch.rand((3,1))

我可以正常训练它:

for i in range(10):
    opt.zero_grad()
    out = net(features)
    loss = torch.mean(torch.square(torch.tensor(5) - torch.sum(out)))
    loss.backward()
    opt.step()

如何附加一个钩子函数,该函数将具有以下条件用于向后传递(对于每一层):

您可以使用 nn.Module.register_full_backward_hook:

nn.Module 上附加回调函数

您将不得不处理这两种情况:如果使用 torch.all, else (i.e. at least one is non zero) if at least one is equal to zero using torch.any 所有元素都为零。

def grad_mod(module, grad_inputs, grad_outputs):
    if module.weight.grad is None: # safety measure for last layer 
        return None                # and layers w/ require_grad=False

    flat = module.weight.grad.view(-1)
    if torch.all(flat == 0):
        flat.data.fill_(1.)
    elif torch.any(flat == 0):
        flat.data.scatter_(0, (flat == 0).nonzero()[:,0], value=.5)

第一个子句中的指令会将所有值填充为 1.,而第二个子句中的指令只会将零值替换为 .5

将挂钩挂在 nn.Module:

>>> net.fc3.register_full_backward_hook(grad_mod)

这里我在变异flat之前和之后使用print语句来展示钩子的效果:

>>> net(torch.rand((3,1))).backward(torch.tensor([[0],[1],[2]]))
>>> tensor([0.0947, 0.0000, 0.0000]) # before
>>> tensor([0.0947, 0.5000, 0.5000]) # after

>>> net(torch.rand((3,1))).backward(torch.tensor([[0],[1],[2]]))
>>> tensor([0., 0., 0.])             # before
>>> tensor([1., 1., 1.])             # after

为了将此挂钩应用到多层,您可以包装 grad_mod 并利用 nn.Module.apply 递归行为:

>>> def apply_grad_mod(module):
...     if hasattr(module, 'weight'):
...         module.register_full_backward_hook(grad_mod)

然后下面将在所有图层权重上应用挂钩。

>>> net.apply(apply_grad_mod)

注意:如果您还想影响偏差,则必须扩展此行为!