PyTorch - 以图像作为标签导入数据集
PyTorch - Import dataset with images as labels
我有一个包含图像作为输入和 labels/targets 作为图像的数据集。文件夹内结构如下:
> DATASET/
> ---TRAIN/
> ------image_xx.png
> ------label_xx.png
> ---TEST/
> ------image_xx.png
> ------label_xx.png
我目前尝试使用 torchvisions 数据集中的“ImageFolder”来加载图像,如下所示:
TRAIN_PATH = '/path/to/dataset/DATASET'
train_data = datasets.ImageFolder(root=TRAIN_PATH, transform=transforms.ToTensor())
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
然而如下图:
for img, label in train_loader:
print(img.shape)
print(label.shape)
break
torch.Size([16, 3, 128, 128])
torch.Size([16])
标签不是图像,而是标记或类似的东西。有没有一种方便的方法可以导入具有上述结构的数据集?
当每个图像具有离散的标量 class 时,ImageFolder
数据集是合适的。它期望目录结构是这样的,即每个子目录都包含某个 class.
对于您的情况,您可以简单地定义您自己的 torch.nn.Dataset
的子class。 This tutorial 可能会有帮助。
一个简单的例子(我没试过运行看看是否能正常工作):
import torch
import os
import cv2
class MyDataset(torch.utils.data.Dataset):
def __init__(self, root_path, transform=None):
self.data_paths = [f for f in sorted(os.listdir(root_path)) if f.startswith("image")]
self.label_paths = [f for f in sorted(os.listdir(root_path)) if f.startswith("label")]
self.transform = transform
def __getitem__(self, idx):
img = cv2.imread(self.data_paths[idx])
label = cv2.imread(self.label_paths[idx])
if self.transform:
img = self.transform(img)
return img, label
def __len__(self):
return len(self.data_paths)
TRAIN_PATH = '/path/to/dataset/DATASET/TRAIN/'
train_data = MyDataset(root_path=TRAIN_PATH, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_data, batch_size=16, shuffle=True)
我有一个包含图像作为输入和 labels/targets 作为图像的数据集。文件夹内结构如下:
> DATASET/
> ---TRAIN/
> ------image_xx.png
> ------label_xx.png
> ---TEST/
> ------image_xx.png
> ------label_xx.png
我目前尝试使用 torchvisions 数据集中的“ImageFolder”来加载图像,如下所示:
TRAIN_PATH = '/path/to/dataset/DATASET'
train_data = datasets.ImageFolder(root=TRAIN_PATH, transform=transforms.ToTensor())
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
然而如下图:
for img, label in train_loader:
print(img.shape)
print(label.shape)
break
torch.Size([16, 3, 128, 128])
torch.Size([16])
标签不是图像,而是标记或类似的东西。有没有一种方便的方法可以导入具有上述结构的数据集?
当每个图像具有离散的标量 class 时,ImageFolder
数据集是合适的。它期望目录结构是这样的,即每个子目录都包含某个 class.
对于您的情况,您可以简单地定义您自己的 torch.nn.Dataset
的子class。 This tutorial 可能会有帮助。
一个简单的例子(我没试过运行看看是否能正常工作):
import torch
import os
import cv2
class MyDataset(torch.utils.data.Dataset):
def __init__(self, root_path, transform=None):
self.data_paths = [f for f in sorted(os.listdir(root_path)) if f.startswith("image")]
self.label_paths = [f for f in sorted(os.listdir(root_path)) if f.startswith("label")]
self.transform = transform
def __getitem__(self, idx):
img = cv2.imread(self.data_paths[idx])
label = cv2.imread(self.label_paths[idx])
if self.transform:
img = self.transform(img)
return img, label
def __len__(self):
return len(self.data_paths)
TRAIN_PATH = '/path/to/dataset/DATASET/TRAIN/'
train_data = MyDataset(root_path=TRAIN_PATH, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_data, batch_size=16, shuffle=True)