如何在 Colab 的 PyTorch 函数中加载定义的已保存模型?

How to load a saved model defined in a function in PyTorch in Colab?

这是我的训练函数的示例代码(删除了不必要的部分):

我试图在 torch.save() 中保存我的模型 data_gen,在 运行 之后 train_dmc功能,我可以在目录中找到检查点文件。

def train_dmc(loader,loss):


 
  data_gen = DataGenerator().to(device)

  data_gen_optimizer = optim.Rprop(para_list, lr=lrate)


  savepath='/content/drive/MyDrive/'+loss+'checkpoint.t7'
  state = {
            'epoch': epoch,
            'model_state_dict': data_gen.state_dict(),
            'optimizer_state_dict': data_gen_optimizer.state_dict(),
            'data loss': data_loss,
            'latent_loss':latent_loss
            }
  torch.save(state,savepath)

我的问题是,如果 Google Colab 断开连接,如何加载检查点文件以继续训练。

应该加载data_gen还是train_dmc(),我是第一次用这个和我真的很困惑,因为 data_gen 是在另一个函数中定义的。希望有人能帮我解释一下

data_gen.load_state_dict(torch.load(PATH))
data_gen.eval()

#or

train_dmc.load_state_dict(torch.load(PATH))
train_dmc.eval()

由于state变量是一个字典,所以尝试将其保存为:

with open('/content/checkpoint.t7', 'wb') as handle:
    pickle.dump(state, handle, protocol=pickle.HIGHEST_PROTOCOL)

启动您的模型 class 作为 data_gen = DataGenerator().to(device)

并将检查点文件加载为:

import pickle
file = open('/content/checkpoint.t7', 'rb')
loaded_state = pickle.load(file)

然后您可以使用 data_gen = loaded_state['model_state_dict'] 加载 state_dict。这会将 state_dict 加载到模型 class!