How to get output from intermediate encoder layers in PyTorch Transformer?

I trained a fairly simple Transformer model with 6 TransformerEncoder layers:

import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self,
                 num_tokens: int,
                 dim_model: int = 96,
                 dim_h: int = 128,
                 n_head: int = 1,
                 dropout: float = 0.1,
                 activation: str = 'relu',
                 num_layers: int = 2,
                 lr: float=1e-3):
        """

        :param num_tokens:
        :param dim_model:
        :param dim_h:
        :param n_head:
        :param dropout:
        :param activation:
        :param num_layers:
        """
        super().__init__()
        self.lr = lr
        self.embed = torch.nn.Embedding(num_embeddings=num_tokens,
                                        embedding_dim=dim_model)
        encoder_layer = torch.nn.TransformerEncoderLayer(d_model=dim_model,
                                                         nhead=n_head,
                                                         dim_feedforward=dim_h,
                                                         dropout=dropout,
                                                         activation=activation,
                                                         batch_first=True)
        self.encoder = torch.nn.TransformerEncoder(encoder_layer=encoder_layer,
                                                   num_layers=num_layers)
        self.linear = torch.nn.Linear(in_features=dim_model, out_features=num_tokens)

    def forward(self, indices, mask):
        x = self.embed(indices)
        x = self.encoder(x, src_key_padding_mask=mask)
        return x

    def training_step(self, batch, batch_idx):
        x = batch['src']
        y = batch['label']
        mask = batch['mask']

        x = self.embed(x)
        x = self.encoder(x, src_key_padding_mask=mask)
        x = self.linear(x)

        loss = F.cross_entropy(input=x.transpose(1, 2),
                               target=y,
                               ignore_index=0)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        # self.lr is stored in __init__ but otherwise unused; Lightning
        # needs an optimizer for training_step to run
        return torch.optim.Adam(self.parameters(), lr=self.lr)

After training the model to predict [MASK] tokens (exactly like BERT), I want to be able to extract the output of a lower layer, specifically the penultimate TransformerEncoderLayer, which may give a better vector encoding than the last layer (according to the original BERT paper). I'm not sure how to do this.

In case the comment wasn't clear: you can capture the output of any intermediate layer by registering a forward hook on it:

# dict that will hold the captured activations, keyed by a name of your choosing
activation = {}

def get_activation(name):
    # build a hook that stashes the layer's output under `name`
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

# instantiate the model
model = LitModel(...)

# register the forward hook on the penultimate encoder layer
model.encoder.layers[-2].register_forward_hook(get_activation('encoder_penultimate_layer'))

# pass a batch through the model; the hook fires during the forward pass
# (indices/mask: a batch of token ids and its padding mask, per forward())
output = model(indices, mask)

# this is what you're looking for
activation['encoder_penultimate_layer']
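
One practical note: register_forward_hook returns a RemovableHandle, so once you've collected the activations you need, you can remove the hook to stop paying the capture cost on every subsequent forward pass:

handle = model.encoder.layers[-2].register_forward_hook(
    get_activation('encoder_penultimate_layer'))
output = model(indices, mask)
handle.remove()  # stop capturing on subsequent forward passes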
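
If you'd rather avoid hooks, you can also drive the encoder layers manually: TransformerEncoder stores them in a ModuleList at .layers, and with norm=None (as in the model above) it simply applies them in sequence. A minimal sketch, assuming indices and mask are again a batch of token ids and its padding mask:

with torch.no_grad():
    x = model.embed(indices)
    # apply every layer except the last; the result is exactly the
    # output of the penultimate TransformerEncoderLayer
    for layer in model.encoder.layers[:-1]:
        x = layer(x, src_key_padding_mask=mask)
penultimate = x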