如何使用 torch.hub.load 加载本地模型?

How do I load a local model with torch.hub.load?

我需要避免从网上下载模型(由于安装机器的限制)。

这有效,但它是从 Internet 下载模型

model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True)

我已将 .pth 文件和 hubconf.py 文件放在 /tmp/ 文件夹中,并将我的代码更改为

model = torch.hub.load('/tmp/', 'deeplabv3_resnet101', pretrained=True, source='local')

但令我惊讶的是,它仍然从互联网上下载模型。我究竟做错了什么?如何在本地加载模型?

只是为了给你更多的细节,我在一个 Docker 容器中执行所有这些操作,该容器在运行时具有只读卷,这就是下载新文件失败的原因。

您可以采用两种方法在没有 Internet 连接的机器上获取可交付模型。

  1. 在普通机器上加载带有预训练模型的DeepLab,使用JIT编译器将其导出为图形,然后放入机器中。脚本很容易理解:

     # To export
     model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True).eval()
     traced_graph = torch.jit.trace(model, torch.randn(1, 3, H, W))
     traced_graph.save('DeepLab.pth')
    
     # To load
     model = torch.jit.load('DeepLab.pth').eval().to(device)
    

    在这种情况下,权重和网络结构被保存为计算图,因此您不需要任何额外的文件。

  2. 看看torchvision's GitHub repository.

    DeepLabV3 有一个 download URL 和 Resnet101 backbone 权重。

    您可以下载这些权重一次,然后使用带有 pretrained=False 标志的 torchvision 的 deeplab 并手动加载权重。

     model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=False)
     model.load_state_dict(torch.load('downloaded weights path'))
    

    考虑一下,state dict 中可能有一个 ['state_dict'] 或一些类似的父键,其中你会使用:

     model.load_state_dict(torch.load('downloaded weights path')['state_dict'])
    
model_name='best.pt'
model = torch.hub.load(os.getcwd(), 'custom', source='local', path = model_name, force_reload = True)

这对我有用。默认来源是 github.