PyTorch - 使用 torchvision.datasets.ImageFolder 标记不正确
PyTorch - Incorrect labeling using torchvision.datasets.ImageFolder
我按以下方式构建了我的数据集:
dataset/train/0/456.jpg
dataset/train/1/456456.jpg
dataset/train/2/456.jpg
dataset/train/...
dataset/val/0/878.jpg
dataset/val/1/234.jpg
dataset/val/2/34554.jpg
dataset/val/...
所以我使用 torchvision.datasets.ImageFolder
将我的数据集导入 PyTorch。但是,它似乎没有为正确的图像提供正确的标签。我在下面添加了我的代码:
data_transforms = {
'train': transforms.Compose(
[transforms.Resize((176,176)),
transforms.RandomRotation((0,360)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.CenterCrop(128),
transforms.Grayscale(),
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
]),
'val': transforms.Compose(
[transforms.Resize((128,128)),
transforms.Grayscale(),
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
]),
}
data_dir = 'dataset'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
shuffle=True, num_workers=4)
for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
我使用以下函数发现标签是错误的:
def imshow(img):
img = img / 2 + 0.5
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
dataiter = iter(dataloaders['val'])
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(labels)
使用显示的图像和标签,我手动检查它们是否正确。不幸的是,标签与图像不对应。有人可以告诉我我做错了什么吗?
ImageFolder API 假定您的数据位于 "predefined" 文件夹结构中。
请检查 PyTorch 代码或文档中的以下注释 @ https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder
A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
这意味着,您需要将数据安排在与您的标签匹配的文件夹下。在上面的例子中有 2 个标签,猫和狗。
希望对您有所帮助!
有人帮我解决了这个问题。 ImageFolder 创建自己的内部标签。通过打印 image_datasets['train'].class_to_idx
你可以看到什么标签与什么内部标签配对。使用这本词典,您可以追溯原始标签。
我按以下方式构建了我的数据集:
dataset/train/0/456.jpg
dataset/train/1/456456.jpg
dataset/train/2/456.jpg
dataset/train/...
dataset/val/0/878.jpg
dataset/val/1/234.jpg
dataset/val/2/34554.jpg
dataset/val/...
所以我使用 torchvision.datasets.ImageFolder
将我的数据集导入 PyTorch。但是,它似乎没有为正确的图像提供正确的标签。我在下面添加了我的代码:
data_transforms = {
'train': transforms.Compose(
[transforms.Resize((176,176)),
transforms.RandomRotation((0,360)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.CenterCrop(128),
transforms.Grayscale(),
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
]),
'val': transforms.Compose(
[transforms.Resize((128,128)),
transforms.Grayscale(),
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
]),
}
data_dir = 'dataset'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
shuffle=True, num_workers=4)
for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
我使用以下函数发现标签是错误的:
def imshow(img):
img = img / 2 + 0.5
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
dataiter = iter(dataloaders['val'])
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(labels)
使用显示的图像和标签,我手动检查它们是否正确。不幸的是,标签与图像不对应。有人可以告诉我我做错了什么吗?
ImageFolder API 假定您的数据位于 "predefined" 文件夹结构中。 请检查 PyTorch 代码或文档中的以下注释 @ https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder
A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
这意味着,您需要将数据安排在与您的标签匹配的文件夹下。在上面的例子中有 2 个标签,猫和狗。
希望对您有所帮助!
有人帮我解决了这个问题。 ImageFolder 创建自己的内部标签。通过打印 image_datasets['train'].class_to_idx
你可以看到什么标签与什么内部标签配对。使用这本词典,您可以追溯原始标签。