为什么我们在 Huggingface Transformers 的 BERT 预训练模型中需要 init_weight 函数?

Why we need the init_weight function in BERT pretrained model in Huggingface Transformers?

在Hugginface transformer的代码中,有很多微调模型具有init_weight的功能。 比如(here),最后有一个init_weight函数

class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

据我所知,它将调用以下内容 code

def _init_weights(self, module):
    """ Initialize the weights """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
    elif isinstance(module, BertLayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()

我的问题是如果我们加载的是预训练模型,为什么我们需要为每个模块初始化权重?

我想我一定是误解了什么。

BertPreTrainedModel 是一个抽象 class 如果你检查,错误是 BertPreTrainedModel class 没有构造函数甚至认为它被调用你可能用 PR 修饰此代码,但请确保先创建问题。

查看 .from_pretrained() 的代码。实际发生的事情是这样的:

  • 找到正确的基础模型class进行初始化
  • 使用伪随机初始化来初始化 class(通过使用您提到的 _init_weights 函数)
  • 找到具有预训练权重的文件
  • 在适用的地方用预训练的权重覆盖我们刚刚创建的模型的权重

这确保未预训练的层(例如在某些情况下最终的 class化层)do_init_weights 中初始化但不不会被覆盖。