无法在 Pytorch Lightning 中加载自定义预训练权重
Unable to load custom pretrained weight in Pytorch Lightning
我想用我的小数据集重新训练自定义模型。我可以在 Pytorch 中加载预训练权重 (.pth) 和 运行。但是,我需要更多功能并将代码重构为 Pytorch 闪电,但我无法弄清楚如何将预训练权重加载到 Pytorch 闪电模型中。
请在下面查看我的代码的详细信息:
class BDRAR(nn.Module):
def __init__(self):
super(BDRAR, self).__init__()
resnext = ResNeXt101()
self.layer0 = resnext.layer0
self.layer1 = resnext.layer1
self.layer2 = resnext.layer2
self.layer3 = resnext.layer3
self.layer4 = resnext.layer4
Pytorch闪电代码:
class liteBDRAR(pl.LightningModule):
def __init__(self):
super(liteBDRAR, self).__init__()
self.model = BDRAR()
print('Model Created!')
def forward(self, x):
return self.model(x)
火炬闪电 运行:
path = './ckpt/BDRAR/3000.pth'
bdrar = liteBDRAR.load_from_checkpoint(path, strict=False)
trainer = pl.Trainer(fast_dev_run=True, gpus=1)
trainer.fit(bdrar)
错误:
keys = model.load_state_dict(checkpoint["state_dict"], strict=strict)
**KeyError: 'state_dict'**
我将不胜感激。
谢谢。
可能是您的 .pth
文件已经是 state_dict
。尝试在闪电中加载预训练重量 class。
class liteBDRAR(pl.LightningModule):
def __init__(self):
super(liteBDRAR, self).__init__()
self.model = BDRAR()
print('Model Created!')
def load_model(self, path):
self.model.load_state_dict(torch.load(path, map_location='cuda:0'), strict=False)
path = './ckpt/BDRAR/3000.pth'
model = liteBDRAR()
model.load_model(path)
您收到此错误的原因是您试图将 PyTorch 的模型权重加载到 Lightning 模块中。使用 Lightning 保存检查点时,您不仅可以保存模型状态,还可以保存大量其他信息(参见 here)。
您要查找的是以下内容:
path = './ckpt/BDRAR/3000.pth'
bdrar = liteBDRAR()
bdrar.model.load_state_dict(torch.load(path))
那些预训练的权重属于class BDRAR(nn.Module)
。也就是说,闪电模块 model
参数中的 class。
LightningModule liteBDRAR()
充当 Pytorch 模型(位于 self.model
)的包装器。您需要将权重加载到闪电模块内的 pytorch 模型上。
正如@Jules 和@Dharman 提到的,你需要的是:
path = './ckpt/BDRAR/3000.pth'
bdrar = liteBDRAR()
bdrar.model.load_state_dict(torch.load(path))
我想用我的小数据集重新训练自定义模型。我可以在 Pytorch 中加载预训练权重 (.pth) 和 运行。但是,我需要更多功能并将代码重构为 Pytorch 闪电,但我无法弄清楚如何将预训练权重加载到 Pytorch 闪电模型中。
请在下面查看我的代码的详细信息:
class BDRAR(nn.Module):
def __init__(self):
super(BDRAR, self).__init__()
resnext = ResNeXt101()
self.layer0 = resnext.layer0
self.layer1 = resnext.layer1
self.layer2 = resnext.layer2
self.layer3 = resnext.layer3
self.layer4 = resnext.layer4
Pytorch闪电代码:
class liteBDRAR(pl.LightningModule):
def __init__(self):
super(liteBDRAR, self).__init__()
self.model = BDRAR()
print('Model Created!')
def forward(self, x):
return self.model(x)
火炬闪电 运行:
path = './ckpt/BDRAR/3000.pth'
bdrar = liteBDRAR.load_from_checkpoint(path, strict=False)
trainer = pl.Trainer(fast_dev_run=True, gpus=1)
trainer.fit(bdrar)
错误:
keys = model.load_state_dict(checkpoint["state_dict"], strict=strict)
**KeyError: 'state_dict'**
我将不胜感激。
谢谢。
可能是您的 .pth
文件已经是 state_dict
。尝试在闪电中加载预训练重量 class。
class liteBDRAR(pl.LightningModule):
def __init__(self):
super(liteBDRAR, self).__init__()
self.model = BDRAR()
print('Model Created!')
def load_model(self, path):
self.model.load_state_dict(torch.load(path, map_location='cuda:0'), strict=False)
path = './ckpt/BDRAR/3000.pth'
model = liteBDRAR()
model.load_model(path)
您收到此错误的原因是您试图将 PyTorch 的模型权重加载到 Lightning 模块中。使用 Lightning 保存检查点时,您不仅可以保存模型状态,还可以保存大量其他信息(参见 here)。
您要查找的是以下内容:
path = './ckpt/BDRAR/3000.pth'
bdrar = liteBDRAR()
bdrar.model.load_state_dict(torch.load(path))
那些预训练的权重属于class BDRAR(nn.Module)
。也就是说,闪电模块 model
参数中的 class。
LightningModule liteBDRAR()
充当 Pytorch 模型(位于 self.model
)的包装器。您需要将权重加载到闪电模块内的 pytorch 模型上。
正如@Jules 和@Dharman 提到的,你需要的是:
path = './ckpt/BDRAR/3000.pth'
bdrar = liteBDRAR()
bdrar.model.load_state_dict(torch.load(path))