RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
这是我在尝试训练我的网络时遇到的错误。
我们用来存储来自 Caltech 101 数据集的图像的 class 是由我们的老师提供给我们的。
from torchvision.datasets import VisionDataset
from PIL import Image
import os
import os.path
import sys
def pil_loader(path):
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
class Caltech(VisionDataset):
def __init__(self, root, split='train', transform=None, target_transform=None):
super(Caltech, self).__init__(root, transform=transform, target_transform=target_transform)
self.split = split # This defines the split you are going to use
# (split files are called 'train.txt' and 'test.txt')
'''
- Here you should implement the logic for reading the splits files and accessing elements
- If the RAM size allows it, it is faster to store all data in memory
- PyTorch Dataset classes use indexes to read elements
- You should provide a way for the __getitem__ method to access the image-label pair
through the index
- Labels should start from 0, so for Caltech you will have lables 0...100 (excluding the background class)
'''
# Open file in read only mode and read all lines
file = open(self.split, "r")
lines = file.readlines()
# Filter out the lines which start with 'BACKGROUND_Google' as asked in the homework
self.elements = [i for i in lines if not i.startswith('BACKGROUND_Google')]
# Delete BACKGROUND_Google class from dataset labels
self.classes = sorted(os.listdir(os.path.join(self.root, "")))
self.classes.remove("BACKGROUND_Google")
def __getitem__(self, index):
'''
__getitem__ should access an element through its index
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
'''
img = Image.open(os.path.join(self.root, self.elements[index].rstrip()))
target = self.classes.index(self.elements[index].rstrip().split('/')[0])
image, label = img, target # Provide a way to access image and label via index
# Image should be a PIL Image
# label can be int
# Applies preprocessing when accessing the image
if self.transform is not None:
image = self.transform(image)
return image, label
def __len__(self):
'''
The __len__ method returns the length of the dataset
It is mandatory, as this is used by several other components
'''
# Provides a way to get the length (number of elements) of the dataset
length = len(self.elements)
return length
而预处理阶段是通过以下代码完成的:
# Define transforms for training phase
train_transform = transforms.Compose([transforms.Resize(256), # Resizes short size of the PIL image to 256
transforms.CenterCrop(224), # Crops a central square patch of the image
# 224 because torchvision's AlexNet needs a 224x224 input!
# Remember this when applying different transformations, otherwise you get an error
transforms.ToTensor(), # Turn PIL Image to torch.Tensor
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalizes tensor with mean and standard deviation
])
# Define transforms for the evaluation phase
eval_transform = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
最后就是数据集和dataloader的准备:
# Clone github repository with data
if not os.path.isdir('./Homework2-Caltech101'):
!git clone https://github.com/MachineLearning2020/Homework2-Caltech101.git
# Commands to execute when there is an error saying no file or directory related to ./Homework2-Caltech101/
# !rm -r ./Homework2-Caltech101/
# !git clone https://github.com/MachineLearning2020/Homework2-Caltech101.git
DATA_DIR = 'Homework2-Caltech101/101_ObjectCategories'
SPLIT_TRAIN = 'Homework2-Caltech101/train.txt'
SPLIT_TEST = 'Homework2-Caltech101/test.txt'
# 1 - Data preparation
myTrainDS = Caltech(DATA_DIR, split = SPLIT_TRAIN, transform=train_transform)
myTestDS = Caltech(DATA_DIR, split = SPLIT_TEST, transform=eval_transform)
print('My Train DS: {}'.format(len(myTrainDS)))
print('My Test DS: {}'.format(len(myTestDS)))
# 1 - Data preparation
myTrain_dataloader = DataLoader(myTrainDS, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
myTest_dataloader = DataLoader(myTestDS, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
好的,现在这两个 .txt 文件包含我们想要在训练和测试拆分中拥有的图像列表,所以我们必须从那里获取它们,但应该已经正确完成了。问题是,当我接近我的训练阶段(见后面的代码)时,我会看到标题中的错误。我已经尝试在转换函数中添加以下行:
[...]
transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
在 centercrop 之后,但它说图像没有属性重复,所以我有点卡住了。
给我错误的训练代码行如下:
# Iterate over the dataset
for images, labels in myTrain_dataloader:
如果需要,完整的错误是:
RuntimeError Traceback (most recent call last)
<ipython-input-197-0e4710a9855d> in <module>()
47
48 # Iterate over the dataset
---> 49 for images, labels in myTrain_dataloader:
50
51 # Bring data over the device of choice
2 frames
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
817 else:
818 del self._task_info[idx]
--> 819 return self._process_data(data)
820
821 next = __next__ # Python 2 compatibility
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
844 self._try_put_index()
845 if isinstance(data, ExceptionWrapper):
--> 846 data.reraise()
847 return data
848
/usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self)
383 # (https://bugs.python.org/issue2651), so we work around it.
384 msg = KeyErrorMessage(msg)
--> 385 raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "<ipython-input-180-0b00b175e18c>", line 72, in __getitem__
image = self.transform(image)
File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 70, in __call__
img = t(img)
File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 175, in __call__
return F.normalize(tensor, self.mean, self.std, self.inplace)
File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py", line 217, in normalize
tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
我正在使用 Alexnet,我得到的代码如下:
net = alexnet() # Loading AlexNet model
# AlexNet has 1000 output neurons, corresponding to the 1000 ImageNet's classes
# We need 101 outputs for Caltech-101
net.classifier[6] = nn.Linear(4096, NUM_CLASSES) # nn.Linear in pytorch is a fully connected layer
# The convolutional layer is nn.Conv2d
# We just changed the last layer of AlexNet with a new fully connected layer with 101 outputs
# It is mandatory to study torchvision.models.alexnet source code
张量的第一维表示颜色,所以你的错误意味着你给出的是灰度图片(1 通道),而数据加载器需要 RGB 图像(3 通道)。您定义了一个 pil_loader 函数,该函数 returns RGB 图像,但您从未使用它。
所以你有两个选择:
使用灰度图像而不是 rgb,这在计算上更便宜。
解决方案: 在训练和测试转换中都将 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
更改为 transforms.Normalize((0.5), (0.5))
确保您的图像是 rgb 格式的。我不知道你的图像是如何存储的,但我猜你下载的是灰度数据集。您可以尝试的一件事是使用您定义的 pil_loader 函数。尝试在 __getitem__
函数中将 img = Image.open(os.path.join(self.root, self.elements[index].rstrip()))
更改为 img = pil_loader(os.path.join(self.root, self.elements[index].rstrip()))
。
让我知道进展如何!
这是我在尝试训练我的网络时遇到的错误。
我们用来存储来自 Caltech 101 数据集的图像的 class 是由我们的老师提供给我们的。
from torchvision.datasets import VisionDataset
from PIL import Image
import os
import os.path
import sys
def pil_loader(path):
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
class Caltech(VisionDataset):
def __init__(self, root, split='train', transform=None, target_transform=None):
super(Caltech, self).__init__(root, transform=transform, target_transform=target_transform)
self.split = split # This defines the split you are going to use
# (split files are called 'train.txt' and 'test.txt')
'''
- Here you should implement the logic for reading the splits files and accessing elements
- If the RAM size allows it, it is faster to store all data in memory
- PyTorch Dataset classes use indexes to read elements
- You should provide a way for the __getitem__ method to access the image-label pair
through the index
- Labels should start from 0, so for Caltech you will have lables 0...100 (excluding the background class)
'''
# Open file in read only mode and read all lines
file = open(self.split, "r")
lines = file.readlines()
# Filter out the lines which start with 'BACKGROUND_Google' as asked in the homework
self.elements = [i for i in lines if not i.startswith('BACKGROUND_Google')]
# Delete BACKGROUND_Google class from dataset labels
self.classes = sorted(os.listdir(os.path.join(self.root, "")))
self.classes.remove("BACKGROUND_Google")
def __getitem__(self, index):
'''
__getitem__ should access an element through its index
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
'''
img = Image.open(os.path.join(self.root, self.elements[index].rstrip()))
target = self.classes.index(self.elements[index].rstrip().split('/')[0])
image, label = img, target # Provide a way to access image and label via index
# Image should be a PIL Image
# label can be int
# Applies preprocessing when accessing the image
if self.transform is not None:
image = self.transform(image)
return image, label
def __len__(self):
'''
The __len__ method returns the length of the dataset
It is mandatory, as this is used by several other components
'''
# Provides a way to get the length (number of elements) of the dataset
length = len(self.elements)
return length
而预处理阶段是通过以下代码完成的:
# Define transforms for training phase
train_transform = transforms.Compose([transforms.Resize(256), # Resizes short size of the PIL image to 256
transforms.CenterCrop(224), # Crops a central square patch of the image
# 224 because torchvision's AlexNet needs a 224x224 input!
# Remember this when applying different transformations, otherwise you get an error
transforms.ToTensor(), # Turn PIL Image to torch.Tensor
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalizes tensor with mean and standard deviation
])
# Define transforms for the evaluation phase
eval_transform = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
最后就是数据集和dataloader的准备:
# Clone github repository with data
if not os.path.isdir('./Homework2-Caltech101'):
!git clone https://github.com/MachineLearning2020/Homework2-Caltech101.git
# Commands to execute when there is an error saying no file or directory related to ./Homework2-Caltech101/
# !rm -r ./Homework2-Caltech101/
# !git clone https://github.com/MachineLearning2020/Homework2-Caltech101.git
DATA_DIR = 'Homework2-Caltech101/101_ObjectCategories'
SPLIT_TRAIN = 'Homework2-Caltech101/train.txt'
SPLIT_TEST = 'Homework2-Caltech101/test.txt'
# 1 - Data preparation
myTrainDS = Caltech(DATA_DIR, split = SPLIT_TRAIN, transform=train_transform)
myTestDS = Caltech(DATA_DIR, split = SPLIT_TEST, transform=eval_transform)
print('My Train DS: {}'.format(len(myTrainDS)))
print('My Test DS: {}'.format(len(myTestDS)))
# 1 - Data preparation
myTrain_dataloader = DataLoader(myTrainDS, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
myTest_dataloader = DataLoader(myTestDS, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
好的,现在这两个 .txt 文件包含我们想要在训练和测试拆分中拥有的图像列表,所以我们必须从那里获取它们,但应该已经正确完成了。问题是,当我接近我的训练阶段(见后面的代码)时,我会看到标题中的错误。我已经尝试在转换函数中添加以下行:
[...]
transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
在 centercrop 之后,但它说图像没有属性重复,所以我有点卡住了。
给我错误的训练代码行如下:
# Iterate over the dataset
for images, labels in myTrain_dataloader:
如果需要,完整的错误是:
RuntimeError Traceback (most recent call last)
<ipython-input-197-0e4710a9855d> in <module>()
47
48 # Iterate over the dataset
---> 49 for images, labels in myTrain_dataloader:
50
51 # Bring data over the device of choice
2 frames
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
817 else:
818 del self._task_info[idx]
--> 819 return self._process_data(data)
820
821 next = __next__ # Python 2 compatibility
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
844 self._try_put_index()
845 if isinstance(data, ExceptionWrapper):
--> 846 data.reraise()
847 return data
848
/usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self)
383 # (https://bugs.python.org/issue2651), so we work around it.
384 msg = KeyErrorMessage(msg)
--> 385 raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "<ipython-input-180-0b00b175e18c>", line 72, in __getitem__
image = self.transform(image)
File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 70, in __call__
img = t(img)
File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 175, in __call__
return F.normalize(tensor, self.mean, self.std, self.inplace)
File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py", line 217, in normalize
tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
我正在使用 Alexnet,我得到的代码如下:
net = alexnet() # Loading AlexNet model
# AlexNet has 1000 output neurons, corresponding to the 1000 ImageNet's classes
# We need 101 outputs for Caltech-101
net.classifier[6] = nn.Linear(4096, NUM_CLASSES) # nn.Linear in pytorch is a fully connected layer
# The convolutional layer is nn.Conv2d
# We just changed the last layer of AlexNet with a new fully connected layer with 101 outputs
# It is mandatory to study torchvision.models.alexnet source code
张量的第一维表示颜色,所以你的错误意味着你给出的是灰度图片(1 通道),而数据加载器需要 RGB 图像(3 通道)。您定义了一个 pil_loader 函数,该函数 returns RGB 图像,但您从未使用它。
所以你有两个选择:
使用灰度图像而不是 rgb,这在计算上更便宜。 解决方案: 在训练和测试转换中都将
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
更改为transforms.Normalize((0.5), (0.5))
确保您的图像是 rgb 格式的。我不知道你的图像是如何存储的,但我猜你下载的是灰度数据集。您可以尝试的一件事是使用您定义的 pil_loader 函数。尝试在
__getitem__
函数中将img = Image.open(os.path.join(self.root, self.elements[index].rstrip()))
更改为img = pil_loader(os.path.join(self.root, self.elements[index].rstrip()))
。
让我知道进展如何!