I can predict one image but not a set of images with a PyTorch resnet18 model. How can I predict a set of images in a list using PyTorch models?
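x is a list of images, each of shape (36, 60, 3). I am trying to use PyTorch's pretrained resnet18 to predict the output for these images. Here x is a list of 2 images. When I pass only 1 image, I get a prediction without errors, as follows: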
import torch
from torchvision import transforms

im = x[0]
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])
# Convert the image to a CHW float tensor and normalize it
img_preprocessed = preprocess(im)
# Add a leading batch dimension of size 1 before feeding the network
batch_img_tensor = torch.unsqueeze(img_preprocessed, 0)
resnet18.eval()
out = resnet18(batch_img_tensor).flatten()
But when I set im = x it does not work. The preprocessing line fails, and I get this error:
TypeError: pic should be PIL Image or ndarray. Got <class 'list'>
I also tried wrapping it in a Variable (torch.tensor(x)) as follows:
x=dataset(source_p)
y=Variable(torch.tensor(x))
print(y.shape)
resnet18(y)
I get the following error:
RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[2, 36, 60, 3] to have 3 channels, but got 36 channels instead
My question is: how can I predict all of the images in the list x at once?
Thanks!
You need to batch the images along the 0th dimension.
# torch.stack needs tensors, so preprocess each image before stacking
im = torch.stack([preprocess(img) for img in x], 0)
In the end I created a class that takes x and transforms all of its elements:
from torch.utils.data import Dataset, DataLoader

class formDataset(Dataset):
    def __init__(self, imgs, transform=None):
        self.imgs = imgs
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.imgs[idx]
        # Use the image itself; {image} would try to build a set
        sample = image
        if self.transform:
            sample = self.transform(sample)
        return sample
Then I call it like this:
l_set = formDataset(imgs=x, transform=preprocess)
l_loader = DataLoader(l_set, batch_size=2)
for data in l_loader:
    # Each batch from the loader is already an (N, 3, H, W) tensor
    features = resnet18(data)