PyTorch Lightning:在检查点文件中包含一些张量对象
PyTorch Lightning: includes some Tensor objects in checkpoint file
由于 Pytorch Lightning 提供了模型检查点的自动保存功能,我用它来保存 top-k 个最佳模型。具体在培训师设置中,
checkpoint_callback = ModelCheckpoint(
monitor='val_acc',
dirpath='checkpoints/',
filename='{epoch:02d}-{val_acc:.2f}',
save_top_k=5,
mode='max',
)
这很好用,但它没有保存模型对象的某些属性。我的模型在每个训练时期结束时都会存储一些张量,这样
class SampleNet(pl.LightningModule):
def __init__(self):
super().__init__()
self.save_hyperparameters()
self.layer = torch.nn.Linear(100, 1)
self.loss = torch.nn.CrossEntropy()
self.some_data = None # Initialize as None
def training_step(self, batch):
x, t = batch
out = self.layer(x)
loss = self.loss(out, t)
results = {'loss': loss}
return results
def training_epoch_end(self, outputs):
self.some_data = some_tensor_object
这是一个简化的示例,但我希望上面 checkpoint_callback
创建的检查点文件记住属性 self.some_data
但是当我从检查点加载模型时,它总是重置为 None
.训练时确认更新成功
我尝试不在 init
中将其初始化为 None,但加载模型时属性会消失。
我想避免将属性保存为不同的 pt
文件,因为它与模型配置相关联,因此我需要稍后手动将文件与相应的检查点文件匹配。
是否可以在检查点文件中包含这样的张量属性?
好像不能直接提取参数,最有可能用到的是nn.Module.state_dict()
。
此方法仅提取实际被视为参数的张量值。因此,在这种情况下,解决方法是将您的数据保存为参数(请参阅 docs):
self.some_data = torch.nn.parameter.Parameter(your_data)
只需使用模型 class 钩子 on_save_checkpoint()
和 on_load_checkpoint()
来获取您想要与默认属性一起保存的各种对象。
def on_save_checkpoint(self, checkpoint) -> None:
"Objects to include in checkpoint file"
checkpoint["some_data"] = self.some_data
def on_load_checkpoint(self, checkpoint) -> None:
"Objects to retrieve from checkpoint file"
self.some_data= checkpoint["some_data"]
由于 Pytorch Lightning 提供了模型检查点的自动保存功能,我用它来保存 top-k 个最佳模型。具体在培训师设置中,
checkpoint_callback = ModelCheckpoint(
monitor='val_acc',
dirpath='checkpoints/',
filename='{epoch:02d}-{val_acc:.2f}',
save_top_k=5,
mode='max',
)
这很好用,但它没有保存模型对象的某些属性。我的模型在每个训练时期结束时都会存储一些张量,这样
class SampleNet(pl.LightningModule):
def __init__(self):
super().__init__()
self.save_hyperparameters()
self.layer = torch.nn.Linear(100, 1)
self.loss = torch.nn.CrossEntropy()
self.some_data = None # Initialize as None
def training_step(self, batch):
x, t = batch
out = self.layer(x)
loss = self.loss(out, t)
results = {'loss': loss}
return results
def training_epoch_end(self, outputs):
self.some_data = some_tensor_object
这是一个简化的示例,但我希望上面 checkpoint_callback
创建的检查点文件记住属性 self.some_data
但是当我从检查点加载模型时,它总是重置为 None
.训练时确认更新成功
我尝试不在 init
中将其初始化为 None,但加载模型时属性会消失。
我想避免将属性保存为不同的 pt
文件,因为它与模型配置相关联,因此我需要稍后手动将文件与相应的检查点文件匹配。
是否可以在检查点文件中包含这样的张量属性?
好像不能直接提取参数,最有可能用到的是nn.Module.state_dict()
。
此方法仅提取实际被视为参数的张量值。因此,在这种情况下,解决方法是将您的数据保存为参数(请参阅 docs):
self.some_data = torch.nn.parameter.Parameter(your_data)
只需使用模型 class 钩子 on_save_checkpoint()
和 on_load_checkpoint()
来获取您想要与默认属性一起保存的各种对象。
def on_save_checkpoint(self, checkpoint) -> None:
"Objects to include in checkpoint file"
checkpoint["some_data"] = self.some_data
def on_load_checkpoint(self, checkpoint) -> None:
"Objects to retrieve from checkpoint file"
self.some_data= checkpoint["some_data"]