PyTorch DataLoader for an image/GT dataset
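I am new to PyTorch. I am trying to create a DataLoader for an image dataset in which every image has a corresponding ground truth (GT) image with the same file name: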
root:
--->RGB:
------>img1.png
------>img2.png
------>...
------>imgN.png
--->GT:
------>img1.png
------>img2.png
------>...
------>imgN.png
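When I pass the path of the root folder (the one containing the RGB and GT folders) to torchvision.datasets.ImageFolder, it reads all the images as if they were all inputs (classified as RGB and GT), and there seems to be no way to pair the RGB and GT images. I would like to pair the RGB-GT images, shuffle them, and split them into batches of a defined size. How can this be done? Any advice would be appreciated.
Thanks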
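I think a good starting point is to use the VisionDataset class as a base. What we are going to rely on here is the DatasetFolder source code, so we will create something similar. Notice that this class depends on two other functions from the datasets.folder module: default_loader and make_dataset.
We are not going to modify default_loader, because it is already fine as it is: it just helps us load images, so we will simply import it.
But we do need a new make_dataset function that prepares the right pairs of images from the root folder. The original make_dataset pairs each image (its path, to be more precise) with its parent folder as the target class (class index), so we get a list of (path, class_to_idx[target]) pairs, whereas we need (rgb_path, gt_path). Here is the code for the new make_dataset: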
import os

def make_dataset(root: str) -> list:
    """Reads a directory with data.
    Returns a dataset as a list of tuples of paired image paths: (rgb_path, gt_path)
    """
    dataset = []
    # Our dir names
    rgb_dir = 'RGB'
    gt_dir = 'GT'
    # Get all the file names from the RGB folder (a set keeps the lookup below fast)
    rgb_fnames = set(os.listdir(os.path.join(root, rgb_dir)))
    # Compare file names from the GT folder to file names from RGB
    for gt_fname in sorted(os.listdir(os.path.join(root, gt_dir))):
        if gt_fname in rgb_fnames:
            # On a match, build the full paths to the corresponding images
            rgb_path = os.path.join(root, rgb_dir, gt_fname)
            gt_path = os.path.join(root, gt_dir, gt_fname)
            # Append the pair to the dataset list
            dataset.append((rgb_path, gt_path))
    return dataset
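Before wiring this into a Dataset, it can help to check the pairing logic on a throwaway directory. The sketch below is only an illustration: the helper names pair_by_name and build_dummy_tree are hypothetical, the files are empty placeholders rather than real images, and it additionally reports file names that exist in only one of the two folders.

```python
import os
import tempfile


def pair_by_name(root, rgb_dir="RGB", gt_dir="GT"):
    """Return (pairs, rgb_only, gt_only) for files matched by file name."""
    rgb_names = set(os.listdir(os.path.join(root, rgb_dir)))
    gt_names = set(os.listdir(os.path.join(root, gt_dir)))
    # Only names present in both folders form a valid (rgb_path, gt_path) pair
    common = sorted(rgb_names & gt_names)
    pairs = [
        (os.path.join(root, rgb_dir, name), os.path.join(root, gt_dir, name))
        for name in common
    ]
    return pairs, sorted(rgb_names - gt_names), sorted(gt_names - rgb_names)


def build_dummy_tree(root):
    """Create empty placeholder files: img3.png lacks a GT, img4.png lacks an RGB."""
    layout = (("RGB", ["img1.png", "img2.png", "img3.png"]),
              ("GT", ["img1.png", "img2.png", "img4.png"]))
    for sub, names in layout:
        os.makedirs(os.path.join(root, sub))
        for name in names:
            open(os.path.join(root, sub, name), "wb").close()


with tempfile.TemporaryDirectory() as tmp:
    build_dummy_tree(tmp)
    pairs, rgb_only, gt_only = pair_by_name(tmp)
    pair_names = [(os.path.basename(r), os.path.basename(g)) for r, g in pairs]
    print(pair_names)
    print("RGB without GT:", rgb_only)
    print("GT without RGB:", gt_only)
```

Listing the unmatched names is optional, but it makes a silently shrinking dataset easier to spot when some GT or RGB files are missing.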
What do we have now? Let's compare our function with the original one:
from torchvision.datasets.folder import make_dataset as make_dataset_original

root = './data'  # folder that contains the RGB and GT sub-folders
dataset_original = make_dataset_original(root, {'RGB': 0, 'GT': 1}, extensions='png')
dataset = make_dataset(root)

print('Original make_dataset:')
print(*dataset_original, sep='\n')
print('Our make_dataset:')
print(*dataset, sep='\n')
Original make_dataset:
('./data/GT/img1.png', 1)
('./data/GT/img2.png', 1)
...
('./data/RGB/img1.png', 0)
('./data/RGB/img2.png', 0)
...
Our make_dataset:
('./data/RGB/img1.png', './data/GT/img1.png')
('./data/RGB/img2.png', './data/GT/img2.png')
...
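I think it works well :) Now it is time to create our Dataset class. The most important part here is the __getitem__ method, because it loads the images, applies the transforms, and returns tensors that the DataLoader can use. We need to read a pair of images (RGB and GT) and return a tuple of 2 tensor images: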
from torchvision.datasets.folder import default_loader
from torchvision.datasets.vision import VisionDataset

class CustomVisionDataset(VisionDataset):
    def __init__(self,
                 root,
                 loader=default_loader,
                 rgb_transform=None,
                 gt_transform=None):
        super().__init__(root,
                         transform=rgb_transform,
                         target_transform=gt_transform)
        # Prepare dataset
        samples = make_dataset(self.root)
        self.loader = loader
        self.samples = samples
        # list of RGB image paths
        self.rgb_samples = [s[0] for s in samples]
        # list of GT image paths
        self.gt_samples = [s[1] for s in samples]

    def __getitem__(self, index):
        """Returns a data sample from our dataset.
        """
        # getting our paths to images
        rgb_path, gt_path = self.samples[index]
        # import each image using loader (by default it's PIL)
        rgb_sample = self.loader(rgb_path)
        gt_sample = self.loader(gt_path)
        # here go the transforms if needed
        # maybe we need different transforms for each type of image
        if self.transform is not None:
            rgb_sample = self.transform(rgb_sample)
        if self.target_transform is not None:
            gt_sample = self.target_transform(gt_sample)
        # now we return the right imported pair of images (tensors)
        return rgb_sample, gt_sample

    def __len__(self):
        return len(self.samples)
Let's test it:
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

bs = 4  # batch size
transforms = ToTensor()  # we need this to convert PIL images to Tensor
shuffle = True

dataset = CustomVisionDataset('./data', rgb_transform=transforms, gt_transform=transforms)
dataloader = DataLoader(dataset, batch_size=bs, shuffle=shuffle)

for i, (rgb, gt) in enumerate(dataloader):
    print(f'batch {i+1}:')
    # some plots; the last batch may hold fewer than bs samples
    for j in range(rgb.size(0)):
        plt.figure(figsize=(10, 5))
        plt.subplot(221)
        # permute CHW -> HWC, the layout matplotlib expects
        plt.imshow(rgb[j].squeeze().permute(1, 2, 0))
        plt.title(f'RGB img{j+1}')
        plt.subplot(222)
        plt.imshow(gt[j].squeeze().permute(1, 2, 0))
        plt.title(f'GT img{j+1}')
        plt.show()
Output:
batch 1:
...
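To see why the RGB-GT pairing survives shuffling, it may help to look at what shuffling and batching do at the index level. The following is a plain-Python sketch of the idea, not PyTorch's actual implementation; shuffled_batches is a hypothetical helper, and the sample count of 10 is made up for illustration. Each index would later be passed to dataset[index], which always returns one (rgb, gt) pair, so only the order of pairs changes.

```python
import random


def shuffled_batches(num_samples, batch_size, seed=None):
    """Yield lists of sample indices, shuffled once, in chunks of batch_size."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    for start in range(0, num_samples, batch_size):
        # The final chunk is shorter when num_samples is not a multiple of batch_size
        yield indices[start:start + batch_size]


batches = list(shuffled_batches(num_samples=10, batch_size=4, seed=0))
for number, batch in enumerate(batches, start=1):
    print(f"batch {number}: {batch}")
```

With 10 samples and a batch size of 4, the last batch holds only 2 indices. In DataLoader the same situation arises unless drop_last=True is passed, which is why code that iterates over a batch should use the actual batch length rather than the nominal batch size.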
Here you can find a notebook with the code and a simple dummy dataset.