如何在pytorch中为nn.Transformer写一个forward hook函数?
How to write a forward hook function for nn.Transformer in pytorch?
我了解到 forward hook 函数的形式为 hook_fn(m,x,y)
。 m 指模型,x 指输入,y 指输出。我想为 nn.Transformer
编写一个前向钩子函数。
然而,对于 transformer 层有输入,即 src 和 tgt。例如,>>> out = transformer_model(src, tgt)
。那么我怎样才能区分这些输入呢?
您的挂钩将使用 元组 为 x
和 y
调用您的回调函数。正如 torch.nn.Module.register_forward_hook
的文档页面中所述(它确实很好地解释了 x
和 y
的类型,尽管 )。
The input contains only the positional arguments given to the module.
Keyword arguments won’t be passed to the hooks and only to the
forward. [...].
model = nn.Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand(10, 32, 512)
tgt = torch.rand(20, 32, 512)
定义回调:
def hook(module, x, y):
print(f'is tuple={isinstance(x, tuple)} - length={len(x)}')
src, tgt = x
print(f'src: {src.shape}')
print(f'tgt: {tgt.shape}')
挂钩到您的 nn.Module
:
>>> model.register_forward_hook(hook)
做一个推理:
>>> out = model(src, tgt)
is tuple=True - length=2
src: torch.Size([10, 32, 512])
tgt: torch.Size([20, 32, 512])
我了解到 forward hook 函数的形式为 hook_fn(m,x,y)
。 m 指模型,x 指输入,y 指输出。我想为 nn.Transformer
编写一个前向钩子函数。
然而,对于 transformer 层有输入,即 src 和 tgt。例如,>>> out = transformer_model(src, tgt)
。那么我怎样才能区分这些输入呢?
您的挂钩将使用 元组 为 x
和 y
调用您的回调函数。正如 torch.nn.Module.register_forward_hook
的文档页面中所述(它确实很好地解释了 x
和 y
的类型,尽管 )。
The input contains only the positional arguments given to the module. Keyword arguments won’t be passed to the hooks and only to the forward. [...].
model = nn.Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand(10, 32, 512)
tgt = torch.rand(20, 32, 512)
定义回调:
def hook(module, x, y):
print(f'is tuple={isinstance(x, tuple)} - length={len(x)}')
src, tgt = x
print(f'src: {src.shape}')
print(f'tgt: {tgt.shape}')
挂钩到您的 nn.Module
:
>>> model.register_forward_hook(hook)
做一个推理:
>>> out = model(src, tgt)
is tuple=True - length=2
src: torch.Size([10, 32, 512])
tgt: torch.Size([20, 32, 512])