从每个输入的 bert 模型中提取并连接最后 4 个隐藏状态

extract and concanate the last 4 hidden states from bert model for each input

我想为每个输入句子从 bert 中提取并连接 4 个最后的隐藏状态并保存它们 我使用这段代码,但我只得到了最后的隐藏状态

class MixModel(nn.Module):
    def __init__(self,pre_trained='bert-base-uncased'):
        super().__init__()        
        self.bert =  AutoModel.from_pretrained('distilbert-base-uncased')
        self.hidden_size = self.bert.config.hidden_size
        
      
           
    def forward(self,inputs, mask , labels):
        
        cls_hs = self.bert(input_ids=inputs,attention_mask=mask, return_dict= False,  output_hidden_states=True)        
        print(cls_hs)        
                   
        encoded_layers = cls_hs[0]
        print(len(encoded_layers))

        print(encoded_layers.size())
        #output is [1,64,768]
       
        return encoded_layers

批量大小为 1 填充大小为 64

如何提取后四位?

获取最后 4 个隐藏状态,现在是 4 个形状张量的元组 (batch_size、seq_len、hidden_size)

encoded_layers = cls_hs['hidden_states'][-4:]

并将它们(这里是最后一个维度)连接成一个形状为 (batch_size, seq_len, 4 * hidden_size)

的张量
concatenated = torch.cat(encoded_layers, -1)