如何在未连接到互联网时缓存 Pytorch 模型以供使用?

How to cache Pytorch models for use when not connected to the internet?

我在分类问题中使用 vgg19。我可以访问校园研究计算机进行训练,但完成计算的节点无法访问互联网。所以 运行 一行代码如 self.net = models.vgg19(pretrained=True) 失败并出现错误 urllib.error.URLError: <urlopen error [Errno 101] Network is unreachable>

有没有一种方法可以在头节点(我可以访问互联网)上缓存模型,然后从缓存而不是计算节点上的互联网加载模型?

如果您只是将预训练网络的权重保存在某处,则可以像加载任何其他网络权重一样加载它们。

节省:

import torchvision

#  I am assuming we have internet access here
model = torchvision.models.vgg16(pretrained=True)
torch.save(model.state_dict(), "Somewhere")

正在加载:

import torchvision

def create_vgg16(dict_path=None):
    model = torchvision.models.vgg16(pretrained=False)
    if (dict_path != None):
        model.load_state_dict(torch.load(dict_path))
    return model

model = create_vgg16("Somewhere")