有什么办法可以将 PyTorch 中可用的预训练模型下载到特定路径?
Is there any way I can download the pre-trained models available in PyTorch to a specific path?
我指的是可以在这里找到的模型:https://pytorch.org/docs/stable/torchvision/models.html#torchvision-models
TL;DR: 不,直接是不可能的,但你可以很容易地适应它。
我想你要做的是看一下torch.utils.model_zoo
,加载预训练模型时内部调用:
如果我们查看预训练模型的代码,例如 AlexNet here,我们可以看到它只是调用前面提到的 model_zoo
函数,但没有保存位置。您可以修改 PyTorch 源代码以指定它(这实际上是 IMO 的一个很好的补充,因此可以为此打开一个拉取请求),或者根据您自己的喜好简单地采用第二个 link 中的代码(并且将其保存到不同名称下的自定义位置),然后在那里手动插入相关位置。
如果您想定期更新 PyTorch,我会强烈推荐第二种方法,因为它不涉及直接更改 PyTorch 的代码库,并且可能在更新过程中抛出错误。
是的,您可以简单地复制网址并使用wget
将其下载到所需的路径。这是一个例子:
对于 AlexNet:
$ wget -c https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth
Google 盗梦空间 (v3):
$ wget -c https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth
对于SqueezeNet:
$ wget -c https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth
对于MobileNetV2:
$ wget -c https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
对于DenseNet201:
$ wget -c https://download.pytorch.org/models/densenet201-c1103571.pth
对于MNASNet1_0:
$ wget -c https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth
对于ShuffleNetv2_x1.0:
$ wget -c https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth
如果您想在 Python 中执行此操作,请使用类似的内容:
In [11]: from six.moves import urllib
# resnet 101 host url
In [12]: url = "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth"
# download and rename the file to `resnet_101.pth`
In [13]: urllib.request.urlretrieve(url, "resnet_101.pth")
Out[13]: ('resnet_101.pth', <http.client.HTTPMessage at 0x7f7fd7f53438>)
P.S:您可以在torchvision.models
的各个python模块中找到下载网址
因为,@dennlinger mentioned in his : torch.utils.model_zoo
,在您加载预训练模型时被内部调用。
更具体地说,方法:torch.utils.model_zoo.load_url()
每次加载预训练模型时都会被调用。同样的文档提到:
The default value of model_dir
is $TORCH_HOME/models
where
$TORCH_HOME
defaults to ~/.torch
.
The default directory can be overridden with the $TORCH_HOME
environment variable.
这可以按如下方式完成:
import torch
import torchvision
import os
# Suppose you are trying to load pre-trained resnet model in directory- models\resnet
os.environ['TORCH_HOME'] = 'models\resnet' #setting the environment variable
resnet = torchvision.models.resnet18(pretrained=True)
我通过在 PyTorch 的 GitHub 存储库中提出问题来找到上述解决方案:
https://github.com/pytorch/vision/issues/616
这导致了文档的改进,即上述解决方案。
我指的是可以在这里找到的模型:https://pytorch.org/docs/stable/torchvision/models.html#torchvision-models
TL;DR: 不,直接是不可能的,但你可以很容易地适应它。
我想你要做的是看一下torch.utils.model_zoo
,加载预训练模型时内部调用:
如果我们查看预训练模型的代码,例如 AlexNet here,我们可以看到它只是调用前面提到的 model_zoo
函数,但没有保存位置。您可以修改 PyTorch 源代码以指定它(这实际上是 IMO 的一个很好的补充,因此可以为此打开一个拉取请求),或者根据您自己的喜好简单地采用第二个 link 中的代码(并且将其保存到不同名称下的自定义位置),然后在那里手动插入相关位置。
如果您想定期更新 PyTorch,我会强烈推荐第二种方法,因为它不涉及直接更改 PyTorch 的代码库,并且可能在更新过程中抛出错误。
是的,您可以简单地复制网址并使用wget
将其下载到所需的路径。这是一个例子:
对于 AlexNet:
$ wget -c https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth
Google 盗梦空间 (v3):
$ wget -c https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth
对于SqueezeNet:
$ wget -c https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth
对于MobileNetV2:
$ wget -c https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
对于DenseNet201:
$ wget -c https://download.pytorch.org/models/densenet201-c1103571.pth
对于MNASNet1_0:
$ wget -c https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth
对于ShuffleNetv2_x1.0:
$ wget -c https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth
如果您想在 Python 中执行此操作,请使用类似的内容:
In [11]: from six.moves import urllib
# resnet 101 host url
In [12]: url = "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth"
# download and rename the file to `resnet_101.pth`
In [13]: urllib.request.urlretrieve(url, "resnet_101.pth")
Out[13]: ('resnet_101.pth', <http.client.HTTPMessage at 0x7f7fd7f53438>)
P.S:您可以在torchvision.models
的各个python模块中找到下载网址因为,@dennlinger mentioned in his torch.utils.model_zoo
,在您加载预训练模型时被内部调用。
更具体地说,方法:torch.utils.model_zoo.load_url()
每次加载预训练模型时都会被调用。同样的文档提到:
The default value of
model_dir
is$TORCH_HOME/models
where$TORCH_HOME
defaults to~/.torch
.The default directory can be overridden with the
$TORCH_HOME
environment variable.
这可以按如下方式完成:
import torch
import torchvision
import os
# Suppose you are trying to load pre-trained resnet model in directory- models\resnet
os.environ['TORCH_HOME'] = 'models\resnet' #setting the environment variable
resnet = torchvision.models.resnet18(pretrained=True)
我通过在 PyTorch 的 GitHub 存储库中提出问题来找到上述解决方案: https://github.com/pytorch/vision/issues/616
这导致了文档的改进,即上述解决方案。