tf2.keras 如何在微调中冻结某些 BERT 层
How to freeze some layers of BERT in fine tuning in tf2.keras
我正在尝试微调 'bert-based-uncased' 文本分类任务的数据集。这是我下载模型的方式:
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=num_labels)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
由于 bert-base 有 12 层,我只想微调最后 2 层以防止过度拟合。 model.layers[i].trainable = False
无济于事。因为 model.layers[0]
给出了整个 bert 基础模型,如果我将 trainable
参数设置为 False
,那么所有的 bert 层都将被冻结。这是model
的架构:
Model: "tf_bert_for_sequence_classification"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
bert (TFBertMainLayer) multiple 109482240
dropout_37 (Dropout) multiple 0
classifier (Dense) multiple 9997
=================================================================
Total params: 109,492,237
Trainable params: 109,492,237
Non-trainable params: 0
_________________________________________________________________
另外,我想用model.layers[0].weights[j]._trainable = False
;但是 weights
列表有 199 个 TensorShape([30522, 768])
形状的元素。所以我无法弄清楚哪些权重与最后两层有关。
谁能帮我解决这个问题?
我找到了答案并在这里分享。希望它可以帮助别人。
借助this article,这是关于使用pytorch微调bert,tensorflow2.keras中的等价物如下:
model.bert.encoder.layer[i].trainable = False
其中 i 是适当图层的索引。
我正在尝试微调 'bert-based-uncased' 文本分类任务的数据集。这是我下载模型的方式:
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=num_labels)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
由于 bert-base 有 12 层,我只想微调最后 2 层以防止过度拟合。 model.layers[i].trainable = False
无济于事。因为 model.layers[0]
给出了整个 bert 基础模型,如果我将 trainable
参数设置为 False
,那么所有的 bert 层都将被冻结。这是model
的架构:
Model: "tf_bert_for_sequence_classification"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
bert (TFBertMainLayer) multiple 109482240
dropout_37 (Dropout) multiple 0
classifier (Dense) multiple 9997
=================================================================
Total params: 109,492,237
Trainable params: 109,492,237
Non-trainable params: 0
_________________________________________________________________
另外,我想用model.layers[0].weights[j]._trainable = False
;但是 weights
列表有 199 个 TensorShape([30522, 768])
形状的元素。所以我无法弄清楚哪些权重与最后两层有关。
谁能帮我解决这个问题?
我找到了答案并在这里分享。希望它可以帮助别人。 借助this article,这是关于使用pytorch微调bert,tensorflow2.keras中的等价物如下:
model.bert.encoder.layer[i].trainable = False
其中 i 是适当图层的索引。