pic 应该是 Tensor 或 ndarray。得到 <class ‘NoneType’>
pic should be Tensor or ndarray. Got <class ‘NoneType’>
我是 PyTorch 的初学者。我想使用 NYU 数据集训练网络,但出现错误。
我使用Dataloader加载本地数据集时出现错误,我想打印数据以证明代码正确:
test=Mydataset(data_root,transforms,'image_train')
test2=DataLoader(test,batch_size=4,num_workers=0,shuffle=False)
for idx,data in enumerate(test2):
print(idx)
下面是带有 Mydataset
定义的其余代码:
from __future__ import division,absolute_import,print_function
from PIL import Image
from torch.utils.data import DataLoader,Dataset
from torchvision.transforms import transforms
data_root='D:/AuxiliaryDocuments/NYU/'
transforms=transforms.Compose([transforms.ToPILImage(),
transforms.Resize(224,101),
transforms.ToTensor()])
filename_txt={'image_train':'image_train.txt','image_test':'image_test.txt',
'depth_train':'depth_train.txt','depth_test':'depth_test.txt'}
class Mydataset(Dataset):
def __init__(self,data_root,transformation,data_type):
self.transform=transformation
self.image_path_txt=filename_txt[data_type]
self.sample_list=list()
f=open(data_root+'/'+data_type+'/'+self.image_path_txt)
lines=f.readlines()
for line in lines:
line=line.strip()
line=line.replace(';','')
self.sample_list.append(line)
f.close()
def __getitem__(self, index):
item=self.sample_list[index]
img=Image.open(item)
if self.transform is not None:
img=self.transform(img)
idx=index
return idx,img
def __len__(self):
return len(self.sample_list)
标题中的错误与图片中的错误不同(顺便说一下,您应该将其作为文本发布)。假设图片中的那个是正确的,您的问题如下:
您的 transforms
以 transforms.ToPILImage()
开头,但图像已被数据加载器读取为 PIL 图像。如果删除该转换,代码应该 运行 就好了。
# [...]
transforms = transforms.Compose([
transforms.ToPILImage(), # <<< remove this
transforms.Resize(224, 101),
transforms.ToTensor()
])
# [...]
class Mydataset(Dataset):
# [...]
def __getitem__(self, index):
item = self.sample_list[index]
img = Image.open(item) # <<< this image is already a PIL image
if self.transform is not None:
img = self.transform(img)
idx = index
return idx, img
# [...]
我是 PyTorch 的初学者。我想使用 NYU 数据集训练网络,但出现错误。
我使用Dataloader加载本地数据集时出现错误,我想打印数据以证明代码正确:
test=Mydataset(data_root,transforms,'image_train')
test2=DataLoader(test,batch_size=4,num_workers=0,shuffle=False)
for idx,data in enumerate(test2):
print(idx)
下面是带有 Mydataset
定义的其余代码:
from __future__ import division,absolute_import,print_function
from PIL import Image
from torch.utils.data import DataLoader,Dataset
from torchvision.transforms import transforms
data_root='D:/AuxiliaryDocuments/NYU/'
transforms=transforms.Compose([transforms.ToPILImage(),
transforms.Resize(224,101),
transforms.ToTensor()])
filename_txt={'image_train':'image_train.txt','image_test':'image_test.txt',
'depth_train':'depth_train.txt','depth_test':'depth_test.txt'}
class Mydataset(Dataset):
def __init__(self,data_root,transformation,data_type):
self.transform=transformation
self.image_path_txt=filename_txt[data_type]
self.sample_list=list()
f=open(data_root+'/'+data_type+'/'+self.image_path_txt)
lines=f.readlines()
for line in lines:
line=line.strip()
line=line.replace(';','')
self.sample_list.append(line)
f.close()
def __getitem__(self, index):
item=self.sample_list[index]
img=Image.open(item)
if self.transform is not None:
img=self.transform(img)
idx=index
return idx,img
def __len__(self):
return len(self.sample_list)
标题中的错误与图片中的错误不同(顺便说一下,您应该将其作为文本发布)。假设图片中的那个是正确的,您的问题如下:
您的 transforms
以 transforms.ToPILImage()
开头,但图像已被数据加载器读取为 PIL 图像。如果删除该转换,代码应该 运行 就好了。
# [...]
transforms = transforms.Compose([
transforms.ToPILImage(), # <<< remove this
transforms.Resize(224, 101),
transforms.ToTensor()
])
# [...]
class Mydataset(Dataset):
# [...]
def __getitem__(self, index):
item = self.sample_list[index]
img = Image.open(item) # <<< this image is already a PIL image
if self.transform is not None:
img = self.transform(img)
idx = index
return idx, img
# [...]