Classification with pretrained pytorch vgg16 model and its classes

I wrote an image classification model using PyTorch's pretrained vgg16 model.

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
import urllib.request
from skimage.transform import resize
from skimage import io
import yaml

# Downloading imagenet 1000 classes list
file = urllib.request.urlopen("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
classes = ''
for f in file:
    classes = classes + f.decode("utf-8")
# safe_load avoids the missing-Loader error raised by yaml.load in recent PyYAML
classes = yaml.safe_load(classes)

# Downloading pretrained vgg16 model
model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg16', pretrained=True)

print(model)

for param in model.parameters():
    param.requires_grad = False


url, filename = ("https://raw.githubusercontent.com/pytorch/hub/master/dog.jpg", "dog.jpg")

image=io.imread(url)

plt.imshow(image)
plt.show()

# resize to 224x224x3
img = resize(image,(224,224,3))

plt.imshow(img)
plt.show()
# Normalizing input for vgg16
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
img1 = mean*img+std
img1 = np.clip(img1,0,1)

img1 = torch.from_numpy(img1).unsqueeze(0)
img1 = img1.permute(0,3,2,1) # batch_size x channels x height x width

model.eval()
pred = model(img1.float())
print(classes[torch.argmax(pred).numpy().tolist()])

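The code runs fine but outputs the wrong classes. I'm not sure where I went wrong, but if I had to guess, it could be the imagenet yaml class list or the normalization of the input image. Can anyone tell me what I'm doing wrong?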

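There are a few problems with the image preprocessing. First, normalization is computed as (value - mean) / std, not value * mean + std. Second, the values should not be clipped to [0, 1]; normalization intentionally moves values outside [0, 1]. Third, the image as a NumPy array has shape [height, width, 3], so permuting the dimensions with (0, 3, 2, 1) swaps the height and width dimensions, creating a tensor of shape [batch_size, channels, width, height].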
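As a quick illustration of the first two points, here is a minimal sketch in plain Python that applies (value - mean) / std to the extreme inputs 0.0 and 1.0 for each channel. The helper name normalize is hypothetical; the mean and std values are the ones from the question.

```python
# ImageNet channel statistics quoted in the question
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def normalize(value, channel):
    """Standard normalization: subtract the channel mean, divide by its std."""
    return (value - mean[channel]) / std[channel]

# Normalized range for inputs in [0, 1], computed per channel
ranges = [(normalize(0.0, c), normalize(1.0, c)) for c in range(3)]
for c, (lo, hi) in enumerate(ranges):
    print(f"channel {c}: [{lo:.3f}, {hi:.3f}]")
```

Every channel ends up spanning roughly -2 to a bit above 2, so clipping the result to [0, 1] would throw away most of the signal the network expects.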

img = resize(image,(224,224,3))


# Normalizing input for vgg16
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
img1 = (img - mean) / std  # start from the resized image; NumPy broadcasts over the channel axis

img1 = torch.from_numpy(img1).unsqueeze(0)
img1 = img1.permute(0, 3, 1, 2) # batch_size x channels x height x width
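The effect of the permutation order is invisible on a square 224x224 input, because both orders give the same shape and the image is silently transposed. A small sketch on a hypothetical non-square array makes the difference visible. It uses NumPy's transpose, which takes the same axis order as torch's permute; the 2x5 size is only an assumption for illustration.

```python
import numpy as np

# Hypothetical non-square image: height 2, width 5, 3 channels
img = np.zeros((2, 5, 3))
batch = img[np.newaxis, ...]          # shape (1, 2, 5, 3): batch, height, width, channels

wrong = batch.transpose(0, 3, 2, 1)   # axis order used in the question
right = batch.transpose(0, 3, 1, 2)   # axis order used in the corrected code

print(wrong.shape)  # (1, 3, 5, 2): batch, channels, width, height
print(right.shape)  # (1, 3, 2, 5): batch, channels, height, width
```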

Instead of doing that manually, you can use torchvision.transforms:

from torchvision import transforms

preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

img = resize(image,(224,224,3))
img1 = preprocess(img)
img1 = img1.unsqueeze(0)

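If you load the image with PIL, you can also resize it by adding transforms.Resize((224, 224)) to the preprocessing pipeline, or you could even add transforms.ToPILImage() to convert the image to a PIL image first (transforms.Resize requires a PIL image, at least in older torchvision releases such as the v0.6.0 used here).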