如何在pytorch中为nn.Transformer写一个forward hook函数?

How to write a forward hook function for nn.Transformer in pytorch?

我了解到 forward hook 函数的形式为 hook_fn(m,x,y)。 m 指模型,x 指输入,y 指输出。我想为 nn.Transformer 编写一个前向钩子函数。
然而,对于 transformer 层有输入,即 src 和 tgt。例如,>>> out = transformer_model(src, tgt)。那么我怎样才能区分这些输入呢?

您的挂钩将使用 元组 xy 调用您的回调函数。正如 torch.nn.Module.register_forward_hook 的文档页面中所述(它确实很好地解释了 xy 的类型,尽管 )。

The input contains only the positional arguments given to the module. Keyword arguments won’t be passed to the hooks and only to the forward. [...].

model = nn.Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand(10, 32, 512)
tgt = torch.rand(20, 32, 512)

定义回调:

def hook(module, x, y):
    print(f'is tuple={isinstance(x, tuple)} - length={len(x)}')      
    src, tgt = x
  
    print(f'src: {src.shape}')
    print(f'tgt: {tgt.shape}')

挂钩到您的 nn.Module:

>>> model.register_forward_hook(hook)

做一个推理:

>>> out = model(src, tgt)
is tuple=True - length=2
src: torch.Size([10, 32, 512])
tgt: torch.Size([20, 32, 512])