Getting model class labels from torchvision pretrained models

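I am using a pretrained AlexNet model from torchvision (without fine-tuning). The problem is that although I can run the model on some data and get an output probability distribution, I cannot find the class labels to map it to.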

Following this official documentation:

import torch
model = torch.hub.load('pytorch/vision:v0.6.0', 'alexnet', pretrained=True)
model.eval()
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

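After some image preprocessing steps, I can use the model to get the output for a single image as a vector of shape (1, 1000), to which I will apply softmax to get a probability distribution: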

#Output - 

tensor([-1.6531e+00, -4.3505e+00, -1.8172e+00, -4.2143e+00, -3.1914e+00,
         3.4163e-01,  1.0877e+00,  5.9350e+00,  8.0425e+00, -7.0242e-01,
        -9.4130e-01, -6.0822e-01, -2.4097e-01, -1.9946e+00, -1.5288e+00,
        -3.2656e+00, -5.5800e-01,  1.0524e+00,  1.9211e-01, -4.7202e+00,
        -3.3880e+00,  4.3048e+00, -1.0997e+00,  4.6132e+00, -5.7404e-03,
        -5.3437e+00, -4.7378e+00, -3.3974e+00, -4.1287e+00,  2.9064e-01,
        -3.2955e+00, -6.7051e+00, -4.7232e+00, -4.1778e+00, -2.1859e+00,
        -2.9469e+00,  3.0465e+00, -3.5882e+00, -6.3890e+00, -4.4203e+00,
        -3.3685e+00, -5.0983e+00, -4.9006e+00, -5.5235e+00, -3.7233e+00,
        -4.0204e+00,  2.6998e-01, -4.4702e+00, -5.6617e+00, -5.4880e+00,
        -2.6801e+00, -3.2129e+00, -1.6294e+00, -5.2289e+00, -2.7495e+00,
        -2.6286e+00, -1.8206e+00, -2.3196e+00, -5.2806e+00, -3.7652e+00,
        -3.0987e+00, -4.1421e+00, -5.2531e+00, -4.6505e+00, -3.5815e+00,
        -4.0189e+00, -4.0008e+00, -4.5512e+00, -3.2248e+00, -7.7903e+00,
        -1.4484e+00, -3.8347e+00, -4.5611e+00, -4.3681e+00,  2.7234e-01,
        -4.0162e+00, -4.2136e+00, -5.4524e+00,  1.1744e+00, -4.7785e+00,
        -1.8335e+00,  4.1288e-01,  2.2239e+00, -9.9919e-02,  4.8216e+00,
        -8.4304e-01,  5.6911e-01, -4.0484e+00, -3.3013e+00,  2.8698e+00,
        -1.1419e+00, -9.1690e-01, -2.9284e+00, -2.6097e+00, -1.8213e-01,
        -2.5429e+00, -2.1095e+00,  2.2419e+00, -1.6280e+00,  7.4458e+00,
         2.3184e+00, -5.7408e+00, -7.4332e-01, -5.4066e+00,  1.5177e+01,
        -4.4737e-02,  1.8237e+00, -3.7741e+00,  9.2271e-01, -4.3687e-01,
        -1.4003e+00, -4.3026e+00,  6.3782e-01, -1.0808e+00, -1.4173e+00,
         2.6194e+00, -3.8418e+00,  1.1598e+00, -2.6876e+00, -3.6103e+00,
        -4.9281e+00, -4.1411e+00, -3.3603e+00, -3.4296e+00, -1.4997e+00,
        -2.8381e+00, -1.2843e+00,  1.5745e+00, -1.7449e+00,  4.2903e-01,
         3.1234e-01, -2.8206e+00,  3.6688e-01, -2.1033e+00,  1.6481e+00,
         1.4222e+00, -2.7303e+00, -3.6292e+00,  1.2864e+00, -2.5541e+00,
        -2.9663e+00, -4.1575e+00, -3.1954e+00, -4.6487e-01,  1.8916e+00,
        -7.4721e-01,  4.5986e+00, -2.5443e+00, -6.2003e+00, -1.3215e+00,
        -2.6225e+00,  9.9639e+00,  9.7772e+00,  9.6715e+00,  9.0857e+00,...
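The softmax step mentioned above could be sketched as follows. This is only a minimal illustration, assuming the model output is a 1-D tensor of raw scores; `top_k_classes` and `fake_logits` are hypothetical names made up for this example and are not part of torchvision. Note that it only yields class indices, not names.

```python
import torch


def top_k_classes(logits: torch.Tensor, k: int = 5):
    """Return (probabilities, class indices) of the k most likely classes."""
    # Softmax over the last (class) dimension turns raw scores into probabilities.
    probs = torch.nn.functional.softmax(logits, dim=-1)
    # topk returns the k largest probabilities and their positions (class indices).
    top_probs, top_idx = torch.topk(probs, k)
    return top_probs.tolist(), top_idx.tolist()


# Tiny stand-in for the real 1000-element model output (hypothetical values).
fake_logits = torch.tensor([0.5, 3.0, -1.0, 2.0])
probs, indices = top_k_classes(fake_logits, k=2)
print(indices)  # class indices, highest probability first
```

With the real model, the same function could be called on `output[0]` if `output` has shape (1, 1000).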

Where do I get the class labels? I can't find any way to get them from the model object.

Unfortunately, you cannot get the class label names directly from a torchvision model. However, these models are trained on the ImageNet dataset (hence the 1000 classes).

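As far as I know, you have to get the class name mapping from the web; there is no way to get it from torch. Previously, you could download ImageNet directly with torchvision.datasets.ImageNet, which has a built-in label-to-class-name converter. Now the download link is no longer public, and the data has to be downloaded manually before datasets.ImageNet can use it.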

So rather than downloading the data or trying to use torch, you can simply search online for the ImageNet class-index-to-label mapping. Try here for example.
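Once you have such a mapping saved locally, turning predicted indices into names is a simple lookup. Here is a rough sketch, assuming the mapping is a plain text file with one class name per line, where line i holds the name of class index i. That file format, the helper names `load_labels` and `indices_to_names`, and the three demo labels are all assumptions for illustration; the file you find online may be JSON or use a different layout, and its order must match the class order the model was trained with.

```python
import os
import tempfile


def load_labels(path):
    """Read class names, one per line; line i is the name of class index i."""
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def indices_to_names(indices, labels):
    """Map predicted class indices to their human-readable names."""
    return [labels[i] for i in indices]


# Demo with a tiny made-up label file standing in for the 1000-line ImageNet list.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "labels.txt")  # hypothetical file name
    with open(demo_path, "w", encoding="utf-8") as f:
        f.write("cat\ndog\nbird\n")
    labels = load_labels(demo_path)

names = indices_to_names([2, 0], labels)
print(names)
```

The indices passed in would be the ones obtained from the softmax output, for example via `torch.topk`.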