IndexError: list index out of range when predicting image classes

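I am running predictions on images, with all of the class names written out in the code. My test folder contains 20 images. Can anyone give me a hint as to why I am getting this error? Also, how can we check the model's metrics?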

Code

import numpy as np
import sys, random
import torch
from torchvision import models, transforms
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import glob

# Paths for image directory and model
IMDIR = './test'
MODEL = 'checkpoint/resnet18/Monday_31_May_2021_21h_25m_05s/resnet18-1000-regular.pth'

# Load the model for testing
model = models.resnet18()

model.named_children()

torch.save(model.state_dict, MODEL)
model.eval()

# Class labels for prediction
class_names = ['BC', 'BK', 'CC', 'CL', 'CM', 'DF', 'DG', 'DS', 'HL', 'IF', 'JD', 'JS', 'LD', 'LP', 'LS', 'PO', 'RI',
               'SD', 'SG', 'TO']


# Retrieve a random sample of images from the directory
files = Path(IMDIR).resolve().glob('*.*')
print(files)

images = random.sample(list(files), 1)
print(images)
# Configure plots
fig = plt.figure(figsize=(9, 9))
rows, cols = 3, 3

# Preprocessing transformations
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    # transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(0.5306, 0.1348)
])

# Enable gpu mode, if cuda available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Perform prediction and plot results
with torch.no_grad():
    for num, img in enumerate(images):
        img = Image.open(img).convert('RGB')
        inputs = preprocess(img).unsqueeze(0).cpu()
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        print(preds)
        label = class_names[preds]
        plt.subplot(rows, cols, num + 1)
        plt.title("Pred: " + label)
        plt.axis('off')
        plt.imshow(img)
'''
Sample run: python test.py test
'''

Traceback

Traceback (most recent call last):
  File "/media/khawar/HDD_Khawar/CVPR/pytorch-cifar100/test_box.py", line 57, in <module>
    label = class_names[preds]
IndexError: list index out of range

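Your error comes from not modifying the final linear layer (fc) of the ResNet model. models.resnet18() ends in a layer with 1000 outputs, one per ImageNet class, so torch.max(outputs, 1) can return any index from 0 to 999, while class_names has only 20 entries. Any predicted index of 20 or higher raises IndexError: list index out of range.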
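A size mismatch like this can be reported more clearly than a bare IndexError. Below is a minimal pure-Python sketch. The helper label_for is hypothetical, and it assumes you pass one row of the model output as a list of floats (for example outputs[0].tolist()).

```python
def label_for(logits, class_names):
    """Map one row of model outputs to a class name, failing loudly on a size mismatch."""
    if len(logits) != len(class_names):
        raise ValueError(
            f"model produces {len(logits)} outputs but there are "
            f"{len(class_names)} class names; replace the final fc layer "
            "so that its out_features equals len(class_names)")
    # Argmax over plain floats: index of the highest score
    best = max(range(len(logits)), key=lambda i: logits[i])
    return class_names[best]
```

With an unmodified resnet18 this raises a ValueError that names both sizes (1000 vs 20) instead of failing on the list lookup.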

I suggest adding this code:

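import torch.nn as nn
from torchvision import models

# What you have
model = models.resnet18()

# What you need: a final layer with one output per class instead of 1000
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, len(class_names)))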

This replaces the last linear layer so that it outputs the correct number of nodes, one per entry in class_names.

萨萨克
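The question also asks how to check the model's metrics. That requires ground-truth labels for the test images, which a flat ./test folder does not provide. One common option, which is an assumption about your setup, is to put the images in one subfolder per class. This is the layout torchvision.datasets.ImageFolder reads, and it assigns indices in sorted folder-name order, which matches the alphabetical class_names list. You can then collect the true and predicted indices while looping over the images. Given those two lists, accuracy and per-class precision and recall can be computed as in this plain-Python sketch, where summarize_predictions is a hypothetical helper. scikit-learn's sklearn.metrics.classification_report and confusion_matrix report the same quantities along with F1 scores.

```python
from collections import Counter


def summarize_predictions(y_true, y_pred, class_names):
    """Return (accuracy, {name: (precision, recall)}) from integer label lists."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    pairs = list(zip(y_true, y_pred))
    accuracy = sum(t == p for t, p in pairs) / len(pairs) if pairs else 0.0
    hits = Counter(t for t, p in pairs if t == p)  # true positives per class
    predicted = Counter(y_pred)                     # how often each class was predicted
    actual = Counter(y_true)                        # how often each class really occurs
    per_class = {}
    for i, name in enumerate(class_names):
        precision = hits[i] / predicted[i] if predicted[i] else 0.0
        recall = hits[i] / actual[i] if actual[i] else 0.0
        per_class[name] = (precision, recall)
    return accuracy, per_class
```

For example, summarize_predictions([0, 1, 1, 2], [0, 1, 2, 2], ['A', 'B', 'C']) gives an accuracy of 0.75, with precision 1.0 and recall 0.5 for class 'B'.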