我可以使用 pytorch resnet18 模型预测一个图像但不能预测一组图像,我如何使用 pytorch 模型预测列表中的一组图像?

I can predict one image but not a set of images with a pytorch resnet18 model, how can i predict a set of images in a list using pytorch models?

x 是 (36, 60, 3) 张图像的列表。我正在尝试使用 pytorch 预训练的 resnet18 来预测图像的输出。我将 x 作为 2 张图像的列表。当我只拍摄 1 张图像时,我得到的预测没有错误,如下所示:

im = x[0]
preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])
# Pass the image for preprocessing and the image preprocessed
img_preprocessed = preprocess(im)
# Reshape, crop, and normalize the input tensor for feeding into network for evaluation
batch_img_tensor = torch.unsqueeze(img_preprocessed, 0)
resnet18.eval()
out = resnet18(batch_img_tensor).flatten()

但是当我设置 im=x 时它不起作用。预处理线出现问题,我收到此错误:

TypeError: pic should be PIL Image or ndarray. Got <class 'list'>

我按如下方式尝试了变量 (torch.tensot(x)) :

x=dataset(source_p)
y=Variable(torch.tensor(x))
print(y.shape)
resnet18(y)

我收到以下错误:

RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[2, 36, 60, 3] to have 3 channels, but got 36 channels instead

我的问题是:如何一次预测 x 列表中的所有图像

谢谢!

您需要沿第 0 个维度对图像进行批处理。

im = torch.stack(x, 0)

最终我创建了一个 class 接受 x 并转换所有元素 :

class formDataset(Dataset):

    def __init__(self, imgs, transform=None):
       
        self.imgs = imgs
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.imgs[idx] 
        sample = {image}

        if self.transform:
            sample = self.transform(sample)

        return sample

我打电话后

l_set=formDataset(imgs=x,transform=preprocess)
l_loader = DataLoader(l_set, batch_size=2)

for data in (l_loader):
     features=resnet(outputs)