如何在未连接到互联网时缓存 Pytorch 模型以供使用?
How to cache Pytorch models for use when not connected to the internet?
我在分类问题中使用 vgg19。我可以访问校园研究计算机进行训练,但完成计算的节点无法访问互联网。所以 运行 一行代码如 self.net = models.vgg19(pretrained=True)
失败并出现错误 urllib.error.URLError: <urlopen error [Errno 101] Network is unreachable>
有没有一种方法可以在头节点(我可以访问互联网)上缓存模型,然后从缓存而不是计算节点上的互联网加载模型?
如果您只是将预训练网络的权重保存在某处,则可以像加载任何其他网络权重一样加载它们。
节省:
import torchvision
# I am assuming we have internet access here
model = torchvision.models.vgg16(pretrained=True)
torch.save(model.state_dict(), "Somewhere")
正在加载:
import torchvision
def create_vgg16(dict_path=None):
model = torchvision.models.vgg16(pretrained=False)
if (dict_path != None):
model.load_state_dict(torch.load(dict_path))
return model
model = create_vgg16("Somewhere")
我在分类问题中使用 vgg19。我可以访问校园研究计算机进行训练,但完成计算的节点无法访问互联网。所以 运行 一行代码如 self.net = models.vgg19(pretrained=True)
失败并出现错误 urllib.error.URLError: <urlopen error [Errno 101] Network is unreachable>
有没有一种方法可以在头节点(我可以访问互联网)上缓存模型,然后从缓存而不是计算节点上的互联网加载模型?
如果您只是将预训练网络的权重保存在某处,则可以像加载任何其他网络权重一样加载它们。
节省:
import torchvision
# I am assuming we have internet access here
model = torchvision.models.vgg16(pretrained=True)
torch.save(model.state_dict(), "Somewhere")
正在加载:
import torchvision
def create_vgg16(dict_path=None):
model = torchvision.models.vgg16(pretrained=False)
if (dict_path != None):
model.load_state_dict(torch.load(dict_path))
return model
model = create_vgg16("Somewhere")