Error when loading data in PyTorch: 'Image' object has no attribute 'shape'
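I am fine-tuning resnet152 using code based on ImageNet training in PyTorch. An error occurs while loading data, and it only appears after several batches of images have been processed. How can I fix this?
The following simplified code produces the same error:
Code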
# Data loading code
import torch
from torchvision import datasets, transforms

# train_img_dir: path to the training image folder (defined elsewhere)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

train_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(train_img_dir, transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=256, shuffle=True,
    num_workers=1, pin_memory=True)

# Print the batch index and tensor shapes every 10 batches
for i, (input_x, target) in enumerate(train_loader):
    if i % 10 == 0:
        print(i)
        print(input_x.shape)
        print(target.shape)
Error
0
torch.Size([256, 3, 224, 224])
torch.Size([256])
10
torch.Size([256, 3, 224, 224])
torch.Size([256])
20
torch.Size([256, 3, 224, 224])
torch.Size([256])
30
torch.Size([256, 3, 224, 224])
torch.Size([256])
----------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-48-792d6ca206df> in <module>()
----> 1 for i, (input_x, target) in enumerate(train_loader):
2 if i % 10 == 0:
3 # sample_img = input_x[0]
4 print(i)
5 print(input_x.shape)
/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py in __next__(self)
200 self.reorder_dict[idx] = batch
201 continue
--> 202 return self._process_next_batch(batch)
203
204 next = __next__ # Python 2 compatibility
/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py in _process_next_batch(self, batch)
220 self._put_indices()
221 if isinstance(batch, ExceptionWrapper):
--> 222 raise batch.exc_type(batch.exc_msg)
223 return batch
224
AttributeError: Traceback (most recent call last):
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 41, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 41, in <listcomp>
samples = collate_fn([dataset[i] for i in batch_indices])
File "/usr/local/lib/python3.5/dist-packages/torchvision-0.1.9-py3.5.egg/torchvision/datasets/folder.py", line 118, in __getitem__
img = self.transform(img)
File "/usr/local/lib/python3.5/dist-packages/torchvision-0.1.9-py3.5.egg/torchvision/transforms.py", line 369, in __call__
img = t(img)
File "/usr/local/lib/python3.5/dist-packages/torchvision-0.1.9-py3.5.egg/torchvision/transforms.py", line 706, in __call__
i, j, h, w = self.get_params(img)
File "/usr/local/lib/python3.5/dist-packages/torchvision-0.1.9-py3.5.egg/torchvision/transforms.py", line 693, in get_params
w = min(img.size[0], img.shape[1])
AttributeError: 'Image' object has no attribute 'shape'
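There is a bug in transforms.RandomSizedCrop.get_params(). As the last line of the error message shows, the code uses img.shape where it should use img.size: a PIL Image has a size attribute but no shape attribute.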
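The line containing the bug is executed only when the random crop fails 10 times in a row and the transform falls back to a center crop. That is why the error does not appear on every batch.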
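To make the "fails 10 times, then falls back" behaviour concrete, here is a minimal sketch of that logic. It is modeled on the traceback above, not copied from torchvision: FakeImage, get_crop_params, and the scale and aspect-ratio ranges are assumptions for illustration. FakeImage stands in for a PIL image in that it exposes size as (width, height) and has no shape.

```python
import math
import random


class FakeImage:
    """Hypothetical stand-in for a PIL image: has .size (width, height), no .shape."""

    def __init__(self, width, height):
        self.size = (width, height)


def get_crop_params(img, attempts=10, rng=random):
    """Return (top, left, height, width) for a random-sized crop.

    Tries `attempts` random boxes; if none fits inside the image,
    falls back to a central square crop.
    """
    width, height = img.size
    area = width * height
    for _ in range(attempts):
        # Assumed ranges: 8%-100% of the area, aspect ratio between 3/4 and 4/3.
        target_area = rng.uniform(0.08, 1.0) * area
        aspect_ratio = rng.uniform(3.0 / 4, 4.0 / 3)
        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))
        if rng.random() < 0.5:
            w, h = h, w
        if w <= width and h <= height:
            top = rng.randint(0, height - h)
            left = rng.randint(0, width - w)
            return top, left, h, w
    # Fallback branch: only reached when every attempt failed.
    # It must read img.size here; reading img.shape raises AttributeError.
    side = min(width, height)
    top = (height - side) // 2
    left = (width - side) // 2
    return top, left, side, side


# A very elongated image never fits a random box, so the fallback always runs.
print(get_crop_params(FakeImage(1000, 10)))  # (0, 495, 10, 10)
```

With images of ordinary proportions one of the 10 attempts almost always succeeds, so the fallback branch (and any bug in it) is hit only occasionally, for example on unusually wide or tall images.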
I have submitted a PR to fix it. As a quick fix, you can edit the file /usr/local/lib/python3.5/dist-packages/torchvision-0.1.9-py3.5.egg/torchvision/transforms.py and change every img.shape to img.size.
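If you would rather script that edit than do it by hand, something along these lines should work. This is only a sketch: patch_shape_to_size is a hypothetical helper, it assumes that every img.shape in the file should become img.size (as described above), and it keeps a .bak copy next to the file so the change can be undone. Writing under dist-packages may require elevated permissions.

```python
import shutil
from pathlib import Path


def patch_shape_to_size(path):
    """Replace every 'img.shape' with 'img.size' in the file at `path`.

    Saves a backup with the suffix '.bak' before writing and returns
    the number of replacements made (0 means the file was left untouched).
    """
    path = Path(path)
    source = path.read_text()
    count = source.count("img.shape")
    if count:
        shutil.copy2(str(path), str(path) + ".bak")
        path.write_text(source.replace("img.shape", "img.size"))
    return count
```

Call it with the path of the installed transforms.py from the traceback, then restart the Python process (or notebook kernel) so the patched module is re-imported.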
Edit: The PR has been merged. You can install the latest torchvision from GitHub to get the fix.