BertModel 或 BertForPreTraining

BertModel or BertForPreTraining

我只想将 Bert 用于嵌入,并将 Bert 输出用作我将从头开始构建的分类网络的输入。

我不确定是否要对模型进行微调。

我认为相关的 类 是 BertModel 或 BertForPreTraining。

BertForPreTraining head 包含两个“动作”: self.predictions是MLM(Masked Language Modeling)头,赋予BERT修复语法错误的能力,self.seq_relationship是NSP(Next Sentence Prediction);通常被称为分类头。

class BertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

我认为 NSP 与我的任务无关,因此我可以“覆盖”它。 MLM 的作用是什么?它与我的目标相关吗?还是我应该使用 BertModel?

您应该使用 BertModel 而不是 BertForPreTraining

BertForPreTraining 用于在 Masked Language Model (MLM) 和 Next Sentence Prediction (NSP) 任务上训练 bert。它们不用于分类。

BERT 模型只是给出 BERT 模型的输出,然后您可以微调 BERT 模型以及您在其之上构建的分类器。对于分类,如果它只是在 BERT 模型之上的单层,则可以直接使用 BertForSequenceClassification

无论如何,如果您只想获取 BERT 模型的输出并学习您的分类器(无需微调 BERT 模型),那么您可以使用以下方法冻结 BERT 模型权重:

model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

for param in model.bert.bert.parameters():
    param.requires_grad = False

以上代码借鉴自here