pic 应该是 Tensor 或 ndarray。得到 <class ‘NoneType’>

pic should be Tensor or ndarray. Got <class ‘NoneType’>

我是 PyTorch 的初学者。我想使用 NYU 数据集训练网络,但出现错误。

我使用Dataloader加载本地数据集时出现错误,我想打印数据以证明代码正确:

test=Mydataset(data_root,transforms,'image_train')
test2=DataLoader(test,batch_size=4,num_workers=0,shuffle=False)
for idx,data in enumerate(test2):
  print(idx)

下面是带有 Mydataset 定义的其余代码:

from __future__ import division,absolute_import,print_function
from PIL import Image
from torch.utils.data import DataLoader,Dataset
from torchvision.transforms import transforms
data_root='D:/AuxiliaryDocuments/NYU/'
transforms=transforms.Compose([transforms.ToPILImage(),
                           transforms.Resize(224,101),
                           transforms.ToTensor()])

filename_txt={'image_train':'image_train.txt','image_test':'image_test.txt',
          'depth_train':'depth_train.txt','depth_test':'depth_test.txt'}


class Mydataset(Dataset):
  def __init__(self,data_root,transformation,data_type):
    self.transform=transformation
    self.image_path_txt=filename_txt[data_type]
    self.sample_list=list()
    f=open(data_root+'/'+data_type+'/'+self.image_path_txt)
    lines=f.readlines()
    for line in lines:
        line=line.strip()
        line=line.replace(';','')
        self.sample_list.append(line)
    f.close()

def __getitem__(self, index):
    item=self.sample_list[index]
    img=Image.open(item)
    if self.transform is not None:
        img=self.transform(img)
    idx=index
    return idx,img

def __len__(self):
    return len(self.sample_list)

标题中的错误与图片中的错误不同(顺便说一下,您应该将其作为文本发布)。假设图片中的那个是正确的,您的问题如下:

您的 transformstransforms.ToPILImage() 开头,但图像已被数据加载器读取为 PIL 图像。如果删除该转换,代码应该 运行 就好了。

# [...]
transforms = transforms.Compose([
    transforms.ToPILImage(),  # <<< remove this
    transforms.Resize(224, 101),
    transforms.ToTensor()
])

# [...]

class Mydataset(Dataset):
    # [...]
    def __getitem__(self, index):
        item = self.sample_list[index]
        img = Image.open(item)  # <<< this image is already a PIL image
        if self.transform is not None:
            img = self.transform(img)
        idx = index
        return idx, img
    # [...]