Retrieve the PyTorch model from a PyTorch Lightning model

I trained a PyTorch Lightning model, shown below:
In [16]: MLP
Out[16]:
DecoderMLP(
(loss): RMSE()
(logging_metrics): ModuleList(
(0): SMAPE()
(1): MAE()
(2): RMSE()
(3): MAPE()
(4): MASE()
)
(input_embeddings): MultiEmbedding(
(embeddings): ModuleDict(
(LCLid): Embedding(5, 4)
(sun): Embedding(5, 4)
(day_of_week): Embedding(7, 5)
(month): Embedding(12, 6)
(year): Embedding(3, 3)
(holidays): Embedding(2, 1)
(BusinessDay): Embedding(2, 1)
(day): Embedding(31, 11)
(hour): Embedding(24, 9)
)
)
(mlp): FullyConnectedModule(
(sequential): Sequential(
(0): Linear(in_features=60, out_features=435, bias=True)
(1): ReLU()
(2): Dropout(p=0.13371112461182535, inplace=False)
(3): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
(4): Linear(in_features=435, out_features=435, bias=True)
(5): ReLU()
(6): Dropout(p=0.13371112461182535, inplace=False)
(7): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
(8): Linear(in_features=435, out_features=435, bias=True)
(9): ReLU()
(10): Dropout(p=0.13371112461182535, inplace=False)
(11): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
(12): Linear(in_features=435, out_features=435, bias=True)
(13): ReLU()
(14): Dropout(p=0.13371112461182535, inplace=False)
(15): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
(16): Linear(in_features=435, out_features=435, bias=True)
(17): ReLU()
(18): Dropout(p=0.13371112461182535, inplace=False)
(19): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
(20): Linear(in_features=435, out_features=435, bias=True)
(21): ReLU()
(22): Dropout(p=0.13371112461182535, inplace=False)
(23): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
(24): Linear(in_features=435, out_features=435, bias=True)
(25): ReLU()
(26): Dropout(p=0.13371112461182535, inplace=False)
(27): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
(28): Linear(in_features=435, out_features=435, bias=True)
(29): ReLU()
(30): Dropout(p=0.13371112461182535, inplace=False)
(31): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
(32): Linear(in_features=435, out_features=435, bias=True)
(33): ReLU()
(34): Dropout(p=0.13371112461182535, inplace=False)
(35): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
(36): Linear(in_features=435, out_features=1, bias=True)
)
)
)
I need the corresponding plain PyTorch model for another one of my applications.
Is there an easy way to do this?
I thought of saving a checkpoint, but then I don't know what to do with it.
Can you help?
Thanks
You can manually save the weights of the torch.nn.Modules inside the LightningModule. Something like:
trainer.fit(model, trainloader, valloader)
torch.save(
model.input_embeddings.state_dict(),
"input_embeddings.pt"
)
torch.save(model.mlp.state_dict(), "mlp.pt")
Then, load them without Lightning:
# create the "blank" networks like they
# were created in the Lightning Module
input_embeddings = MultiEmbedding(...)
mlp = FullyConnectedModule(...)
# Load the models for inference
input_embeddings.load_state_dict(
torch.load("input_embeddings.pt")
)
input_embeddings.eval()
mlp.load_state_dict(
torch.load("mlp.pt")
)
mlp.eval()
For more on saving and loading PyTorch modules, see Saving and Loading Models: Saving & Loading Model for Inference in the PyTorch documentation.
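The same state_dict round trip can be sketched with a toy network (the nn.Sequential here is just a stand-in for your MultiEmbedding and FullyConnectedModule; shapes and file name are illustrative):

```python
import torch
import torch.nn as nn

# A tiny stand-in network; in your case this would be the real
# MultiEmbedding / FullyConnectedModule constructed with the same
# hyperparameters as in the LightningModule.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

# Save only the weights, not the Python class definition
torch.save(net.state_dict(), "net.pt")

# Recreate an identically-shaped "blank" network and load the weights
net2 = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
net2.load_state_dict(torch.load("net.pt"))
net2.eval()  # switch off dropout etc. for inference

# Both networks now produce identical outputs
x = torch.randn(2, 4)
same = torch.equal(net(x), net2(x))
```

The key point is that load_state_dict only works if the blank module has exactly the same architecture as the one that was saved, which is why you must rebuild it the way the LightningModule did.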
Since Lightning automatically saves checkpoints to disk (check the lightning_logs folder if you are using the default TensorBoard logger), you can also load the pretrained LightningModule and then save the state dicts without repeating all the training. Instead of calling trainer.fit in the code above, try:
model = DecoderMLP.load_from_checkpoint("path/to/checkpoint.ckpt")
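One caveat if you work with the raw checkpoint file directly (torch.load on the .ckpt) rather than via load_from_checkpoint: the LightningModule's top-level state_dict prefixes every key with the submodule's attribute path (e.g. mlp.sequential.0.weight), so those keys must be stripped before the bare torch.nn.Module can load them. A minimal sketch of that key surgery, with plain-Python placeholders standing in for the real tensors:

```python
# Keys as they would appear in the LightningModule's full state_dict;
# the values here are placeholders, in reality they are tensors.
full_state_dict = {
    "input_embeddings.embeddings.hour.weight": 0,
    "mlp.sequential.0.weight": 0,
    "mlp.sequential.0.bias": 0,
}

def submodule_state_dict(state_dict, prefix):
    """Keep only the keys under `prefix` and strip the prefix, so the
    result can be passed to the bare submodule's load_state_dict()."""
    prefix = prefix + "."
    return {key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)}

mlp_weights = submodule_state_dict(full_state_dict, "mlp")
# keys are now relative to the submodule, e.g. "sequential.0.weight"
```

If you load the full model with load_from_checkpoint as shown above, you can skip this entirely: model.mlp.state_dict() already returns keys relative to the submodule, so you can save and reload it exactly as in the first code snippet.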