无法在 Pytorch Lightning 中加载自定义预训练权重

Unable to load custom pretrained weight in Pytorch Lightning

我想用我的小数据集重新训练自定义模型。我可以在 Pytorch 中加载预训练权重 (.pth) 和 运行。但是,我需要更多功能并将代码重构为 Pytorch 闪电,但我无法弄清楚如何将预训练权重加载到 Pytorch 闪电模型中。

请在下面查看我的代码的详细信息:

class BDRAR(nn.Module):
    def __init__(self):
        super(BDRAR, self).__init__()
        resnext = ResNeXt101()
        self.layer0 = resnext.layer0
        self.layer1 = resnext.layer1
        self.layer2 = resnext.layer2
        self.layer3 = resnext.layer3
        self.layer4 = resnext.layer4

Pytorch闪电代码:

class liteBDRAR(pl.LightningModule):
    def __init__(self):
        super(liteBDRAR, self).__init__()
        self.model = BDRAR()
        print('Model Created!')

    def forward(self, x):
        return self.model(x)

火炬闪电 运行:

    path = './ckpt/BDRAR/3000.pth'
    bdrar = liteBDRAR.load_from_checkpoint(path,  strict=False)
    trainer = pl.Trainer(fast_dev_run=True, gpus=1)
    trainer.fit(bdrar)

错误:

keys = model.load_state_dict(checkpoint["state_dict"], strict=strict)
**KeyError: 'state_dict'**

我将不胜感激。

谢谢。

可能是您的 .pth 文件已经是 state_dict。尝试在闪电中加载预训练重量 class。

class liteBDRAR(pl.LightningModule):
    def __init__(self):
        super(liteBDRAR, self).__init__()
        self.model = BDRAR()
        print('Model Created!')

    def load_model(self, path):
        self.model.load_state_dict(torch.load(path, map_location='cuda:0'), strict=False)

path = './ckpt/BDRAR/3000.pth'
model = liteBDRAR()
model.load_model(path)

您收到此错误的原因是您试图将 PyTorch 的模型权重加载到 Lightning 模块中。使用 Lightning 保存检查点时,您不仅可以保存模型状态,还可以保存大量其他信息(参见 here)。

您要查找的是以下内容:

path = './ckpt/BDRAR/3000.pth'
bdrar = liteBDRAR()
bdrar.model.load_state_dict(torch.load(path))

那些预训练的权重属于class BDRAR(nn.Module)。也就是说,闪电模块 model 参数中的 class。

LightningModule liteBDRAR() 充当 Pytorch 模型(位于 self.model)的包装器。您需要将权重加载到闪电模块内的 pytorch 模型上。 正如@Jules 和@Dharman 提到的,你需要的是:

path = './ckpt/BDRAR/3000.pth'
bdrar = liteBDRAR()
bdrar.model.load_state_dict(torch.load(path))