BertModel 或 BertForPreTraining
BertModel or BertForPreTraining
我只想将 Bert 用于嵌入,并将 Bert 输出用作我将从头开始构建的分类网络的输入。
我不确定是否要对模型进行微调。
我认为相关的 类 是 BertModel 或 BertForPreTraining。
BertForPreTraining head 包含两个“动作”:
self.predictions是MLM(Masked Language Modeling)头,赋予BERT修复语法错误的能力,self.seq_relationship是NSP(Next Sentence Prediction);通常被称为分类头。
class BertPreTrainingHeads(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = BertLMPredictionHead(config)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
我认为 NSP 与我的任务无关,因此我可以“覆盖”它。
MLM 的作用是什么?它与我的目标相关吗?还是我应该使用 BertModel?
您应该使用 BertModel
而不是 BertForPreTraining
。
BertForPreTraining
用于在 Masked Language Model (MLM) 和 Next Sentence Prediction (NSP) 任务上训练 bert。它们不用于分类。
BERT 模型只是给出 BERT 模型的输出,然后您可以微调 BERT 模型以及您在其之上构建的分类器。对于分类,如果它只是在 BERT 模型之上的单层,则可以直接使用 BertForSequenceClassification
。
无论如何,如果您只想获取 BERT 模型的输出并学习您的分类器(无需微调 BERT 模型),那么您可以使用以下方法冻结 BERT 模型权重:
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
for param in model.bert.bert.parameters():
param.requires_grad = False
以上代码借鉴自here
我只想将 Bert 用于嵌入,并将 Bert 输出用作我将从头开始构建的分类网络的输入。
我不确定是否要对模型进行微调。
我认为相关的 类 是 BertModel 或 BertForPreTraining。
BertForPreTraining head 包含两个“动作”: self.predictions是MLM(Masked Language Modeling)头,赋予BERT修复语法错误的能力,self.seq_relationship是NSP(Next Sentence Prediction);通常被称为分类头。
class BertPreTrainingHeads(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = BertLMPredictionHead(config)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
我认为 NSP 与我的任务无关,因此我可以“覆盖”它。 MLM 的作用是什么?它与我的目标相关吗?还是我应该使用 BertModel?
您应该使用 BertModel
而不是 BertForPreTraining
。
BertForPreTraining
用于在 Masked Language Model (MLM) 和 Next Sentence Prediction (NSP) 任务上训练 bert。它们不用于分类。
BERT 模型只是给出 BERT 模型的输出,然后您可以微调 BERT 模型以及您在其之上构建的分类器。对于分类,如果它只是在 BERT 模型之上的单层,则可以直接使用 BertForSequenceClassification
。
无论如何,如果您只想获取 BERT 模型的输出并学习您的分类器(无需微调 BERT 模型),那么您可以使用以下方法冻结 BERT 模型权重:
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
for param in model.bert.bert.parameters():
param.requires_grad = False
以上代码借鉴自here