防止更新卷积权重矩阵的特定元素

Prevent Updating for Specific Element of Convolutional Weight Matrix

我正在尝试将权重的一个元素设置为 1,然后保持不变直到学习结束(防止它在下一个 epoch 中更新)。我知道我可以设置 requires_grad = False 但我只想对一个元素而不是所有元素进行此过程。

您可以在 nn.Module 上附加一个反向钩子,这样在反向传播期间您可以将感兴趣的元素覆盖到 0。这确保它的值永远不会改变,而不会阻止梯度向输入的反向传播。

后向钩子的新 API 是 nn.Module.register_full_backward_hook。首先构造一个回调函数,作为layer hook:

def freeze_single(index):
    def callback(module, grad_input, grad_output):
        module.weight.grad.data[index] = 0
    return callback

然后,我们可以将这个钩子附加到任何 nn.Module。例如,在这里我决定冻结卷积层的组件 [0, 1, 2, 1]

>>> conv = nn.Conv2d(3, 1, 3)
>>> conv.weight.data[0, 1, 2, 1] = 1

>>> conv.register_full_backward_hook(freeze_single((0, 1, 2, 1)))

一切设置正确,让我们试试:

>>> x = torch.rand(1, 3, 10, 10, requires_grad=True)
>>> conv(x).mean().backward()

这里我们可以验证分量[0, 1, 2, 1]的梯度确实等于0:

>>> conv.weight.grad
tensor([[[[0.4954, 0.4776, 0.4639],
          [0.5179, 0.4992, 0.4856],
          [0.5271, 0.5219, 0.5124]],

         [[0.5367, 0.5035, 0.5009],
          [0.5703, 0.5390, 0.5207],
          [0.5422, 0.0000, 0.5109]], # <-

         [[0.4937, 0.5150, 0.5200],
          [0.4817, 0.5070, 0.5241],
          [0.5039, 0.5295, 0.5445]]]])

您可以随时 detach/reattach 挂钩:

>>> hook = conv.register_full_backward_hook(freeze_single((0, 1, 2, 1)))
>>> hook.remove()

不要忘记,如果您移除钩子,该组件的值会在您更新权重时发生变化。如果您愿意,您必须将其重置为 1。否则,您可以实施第二个挂钩 - 这次是 register_forward_pre_hook 挂钩 - 来处理它。

Ivan 已经谈到使用向后钩子来覆盖所需元素的渐变。

另一种不用钩子的方法是在前向传递之前用值覆盖所需的参数。

假设您只想用 1 覆盖线性层的 0,0 元素,您可以这样做

def forward(x)
    model.weight[0,0] = 1
    # usual forward pass    

当你向后传递时,元素会由于通常的梯度更新而更新,但在下一次向前传递时,它会再次被 1 覆盖,并在所有训练计算中保持该值。这也可以通过forward前的hook来实现。