如何在 pytorch 中加载经过修改的 vgg19 网络中的预训练权重?

How to load pretrained weights in modified vgg19 network in pytorch?

我正在尝试使用修改后的输入通道数加载 vgg19 网络。输入通道的数量是 4 是我的情况,而且我正在将分类器更改为我自己的分类器。我还从网络中删除了自适应平均池化层。我应该如何将预训练的权重加载到 PyTorch 模型的修改版本中?

假设我的模型的修改版本在变量 myModel 中。如何将 vgg19 的预训练权重加载到相同的权重中?

选项 1. 如果要使用原始 VGG19 网络给出的原始预训练权重,则必须先加载权重,然后再修改网络。 预训练的权重是为原始网络定义的,因此需要匹配输入通道。 然后你可以在开头添加一个额外的层作为输入层,并在你的新网络中删除池化层。

选项2。您可以单独加载除输入层以外的所有层的权重,因为这会存在维度不匹配。

在代码中它看起来像这样 -

  # corresp_name is a dict object with mapping for your given layer 
  # name and original models layer name
  p_dict = torch.load(Path.model_dir()) #p_dict is my_model
  s_dict = self.state_dict()
  for name in p_dict:
      if name not in corresp_name:
            continue
      s_dict[corresp_name[name]] = p_dict[name]
  self.load_state_dict(s_dict)