有没有办法通过字符串加载 torchvision 模型?
Is there a way to load torchvision model by string?
目前,我使用以下代码加载预训练的 torchvision 模型:
import torchvision
torchvision.models.resnet101(pretrained=True)
但是,我希望将模型名称作为字符串参数,然后使用该字符串加载预训练模型。这样做的伪代码类似于:
model_name = 'resnet101'
torchvision.models.get(model_name)(pretrained=True)
有没有办法以相当简单的方式完成此操作?
您可以使用getattr
getattr(torchvision.models, 'resnet101')(pretrained=True)
您可以使用 torch.hub:
model_str = 'resnet50'
model = torch.hub.load('pytorch/vision', model_str, pretrained=True)
所有可用的字符串模型都可以通过以下方式找到:
torch.hub.list('pytorch/vision', force_reload=True)
输出:
['alexnet',
'deeplabv3_mobilenet_v3_large',
'deeplabv3_resnet101',
'deeplabv3_resnet50',
'densenet121',
'densenet161',
'densenet169',
'densenet201',
'fcn_resnet101',
'fcn_resnet50',
'googlenet',
'inception_v3',
'lraspp_mobilenet_v3_large',
'mnasnet0_5',
'mnasnet0_75',
'mnasnet1_0',
'mnasnet1_3',
'mobilenet_v2',
'mobilenet_v3_large',
'mobilenet_v3_small',
'resnet101',
'resnet152',
'resnet18',
'resnet34',
'resnet50',
'resnext101_32x8d',
'resnext50_32x4d',
'shufflenet_v2_x0_5',
'shufflenet_v2_x1_0',
'squeezenet1_0',
'squeezenet1_1',
'vgg11',
'vgg11_bn',
'vgg13',
'vgg13_bn',
'vgg16',
'vgg16_bn',
'vgg19',
'vgg19_bn',
'wide_resnet101_2',
'wide_resnet50_2']
目前,我使用以下代码加载预训练的 torchvision 模型:
import torchvision
torchvision.models.resnet101(pretrained=True)
但是,我希望将模型名称作为字符串参数,然后使用该字符串加载预训练模型。这样做的伪代码类似于:
model_name = 'resnet101'
torchvision.models.get(model_name)(pretrained=True)
有没有办法以相当简单的方式完成此操作?
您可以使用getattr
getattr(torchvision.models, 'resnet101')(pretrained=True)
您可以使用 torch.hub:
model_str = 'resnet50'
model = torch.hub.load('pytorch/vision', model_str, pretrained=True)
所有可用的字符串模型都可以通过以下方式找到:
torch.hub.list('pytorch/vision', force_reload=True)
输出:
['alexnet',
'deeplabv3_mobilenet_v3_large',
'deeplabv3_resnet101',
'deeplabv3_resnet50',
'densenet121',
'densenet161',
'densenet169',
'densenet201',
'fcn_resnet101',
'fcn_resnet50',
'googlenet',
'inception_v3',
'lraspp_mobilenet_v3_large',
'mnasnet0_5',
'mnasnet0_75',
'mnasnet1_0',
'mnasnet1_3',
'mobilenet_v2',
'mobilenet_v3_large',
'mobilenet_v3_small',
'resnet101',
'resnet152',
'resnet18',
'resnet34',
'resnet50',
'resnext101_32x8d',
'resnext50_32x4d',
'shufflenet_v2_x0_5',
'shufflenet_v2_x1_0',
'squeezenet1_0',
'squeezenet1_1',
'vgg11',
'vgg11_bn',
'vgg13',
'vgg13_bn',
'vgg16',
'vgg16_bn',
'vgg19',
'vgg19_bn',
'wide_resnet101_2',
'wide_resnet50_2']