PyTorch 中的图像特征提取

Image Feature Extraction in PyTorch

我很难理解这段代码。

import torch
import torch.nn as nn
import torchvision.models as models

def ResNet152(out_features = 10):
      return getattr(models, "resnet152")(pretrained=False, num_classes = out_features)

def VGG(out_features = 10):
      return getattr(models, "vgg19")(pretrained=False, num_classes = out_features)

在此代码段中,输入图像的特征由 ResNet152 和 Vgg19 模型提取。但是我有一个问题,是否从这些模型的哪一部分提取特征,无论该部分是最后一个池化层还是分类层之前的层还是其他。

请注意 getattr(models, 'resnet152') 等同于 models.resent152

因此,下面的代码返回模型本身。

getattr(models, "resnet152")(pretrained=False, num_classes = out_features)
# is same as
models.resnet152(pretrained=False, num_classes = out_features)

现在,如果您通过简单打印来查看模型的结构,最后一层是 fully-connected 层,所以这就是您在这里获得的特征。

print(ResNet152())

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
...
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=10, bias=True)
)

VGG()也是如此。