How to freeze a pretrained TFBertForSequenceClassification model?

If I am using the TensorFlow version of a huggingface transformer, how do I freeze the weights of the pretrained encoder so that only the weights of the head layers are optimized?

For the PyTorch implementation, it is done with:
for param in model.base_model.parameters():
    param.requires_grad = False
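
For reference, a quick way to confirm this freeze in PyTorch is to list which parameters still require gradients. This is a minimal sketch, assuming torch and transformers are installed; the model name is illustrative:

from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
for param in model.base_model.parameters():
    param.requires_grad = False

# Only the head parameters (classifier.weight, classifier.bias) should print.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)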

I want to do the same thing for the TensorFlow implementation.

Found a way to do it. Freeze the base model before compiling:

model = TFBertForSequenceClassification.from_pretrained("bert-base-uncased")
model.layers[0].trainable = False
model.compile(...)

Or:

model.bert.trainable = False
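
Putting it together, here is a minimal end-to-end sketch; the optimizer, learning rate, and loss below are illustrative assumptions, not part of the original answer:

import tensorflow as tf
from transformers import TFBertForSequenceClassification

model = TFBertForSequenceClassification.from_pretrained("bert-base-uncased")
model.bert.trainable = False  # same effect as model.layers[0].trainable = False

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

# summary() should now report the encoder's weights as non-trainable.
model.summary()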

After digging through this thread, I think the following code will do no harm in TF2, even though it may be redundant in certain cases.

from transformers import TFBertModel

model = TFBertModel.from_pretrained('./bert-base-uncased')  # local checkpoint directory
for layer in model.layers:
    layer.trainable = False
    # Also mark the individual weight tensors as non-trainable, for good measure.
    for w in layer.weights:
        w._trainable = False
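
To confirm the loop above took effect, one can check the model's trainable-weight collections (a quick sketch):

print(len(model.trainable_weights))      # expected: 0 after the loop above
print(len(model.non_trainable_weights))  # every weight should land here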
from transformers import TFDistilBertForSequenceClassification

model = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
# Iterate over model.layers (iterating over the model object itself fails)
# and freeze the base transformer layer by name.
for _layer in model.layers:
    if _layer.name == 'distilbert':
        print(f"Freezing model layer {_layer.name}")
        _layer.trainable = False
    print(_layer.name)
    print(_layer.trainable)
---
Freezing model layer distilbert
distilbert
False      <----------------
pre_classifier
True
classifier
True
dropout_99
True

Model: "tf_distil_bert_for_sequence_classification_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
distilbert (TFDistilBertMain multiple                  66362880  
_________________________________________________________________
pre_classifier (Dense)       multiple                  590592    
_________________________________________________________________
classifier (Dense)           multiple                  1538      
_________________________________________________________________
dropout_99 (Dropout)         multiple                  0         
=================================================================
Total params: 66,955,010
Trainable params: 592,130
Non-trainable params: 66,362,880   <-----

Without freezing:

Model: "tf_distil_bert_for_sequence_classification_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
distilbert (TFDistilBertMain multiple                  66362880  
_________________________________________________________________
pre_classifier (Dense)       multiple                  590592    
_________________________________________________________________
classifier (Dense)           multiple                  1538      
_________________________________________________________________
dropout_59 (Dropout)         multiple                  0         
=================================================================
Total params: 66,955,010
Trainable params: 66,955,010
Non-trainable params: 0

Change TFDistilBertForSequenceClassification to TFBertForSequenceClassification accordingly. To do that, first run model.summary to verify the name of the base layer; for TFDistilBertForSequenceClassification it is distilbert.
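
That check can also be scripted so the same code works whichever *ForSequenceClassification class is used. A sketch, assuming the base transformer block is always the model's first layer (which holds for these classes):

from transformers import TFBertForSequenceClassification

model = TFBertForSequenceClassification.from_pretrained("bert-base-uncased")
print(model.layers[0].name)        # 'bert' here; 'distilbert' for the DistilBERT variant
model.layers[0].trainable = False  # freeze the base layer, whatever it is named
model.summary()                    # verify the non-trainable parameter count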