从现有的 Torchvision 数据集创建简化的数据集
Creating reduced Dataset from existing Torchvision Dataset
我们都知道常见的 MNIST 数据集,包含在 torchvision.datasets
包中。想象一下,我想创建这个数据集的简化版本,只包含 1 和 0 到 class 仅验证这两个数字而不是所有10 个值。
我看到自定义数据集可以在继承所需数据集的 class 中创建,因此 __getitem__
,其中 return 是给定索引处的项目。所以我这样做了:
class MNIST01(MNIST):
def __getitem__(self, idx):
image, label = super().__getitem__(idx)
if label.item() <= 1:
return image, label
else:
return None
问题是我似乎无法 return 一个 None 值,因为它需要 "contain tensors, numbers, dicts or lists; found class 'NoneType'"。
是否有一种简单的方法可以以类似的方式轻松获得此数据集的缩减版本?
我终于设法解决了 NoneType 问题。保留问题中定义的功能。
class MNIST01(MNIST):
def __getitem__(self, idx):
features, target = super(MNIST01, self).__getitem__(idx)
if target.item() <= 1:
return features, target
我们现在需要为我们的数据加载器定义一个自定义 collate function collate_fn
,它处理样本列表以形成一个批次。在这个函数中,我们可以应用过滤器来处理 None
值并忽略它们。
from torch.utils.data.dataloader import default_collate
def filter_collate(batch):
batch = list(filter(lambda x: x is not None, batch))
return default_collate(batch)
那么我们只需要将这个函数传递给DataLoader
:
from torch.utils.data import DataLoader
train_loader = DataLoader(train_dataset, collate_fn=filter_collate, **kwargs)
test_loader = DataLoader(test_dataset, collate_fn=filter_collate, **kwargs)
版本 2
比第一个简单多了,避免了一些访问数据的问题。只需从 MNIST
class.
的实例化中直接过滤 train_data
和 train_label
属性(以及对应的测试集)
train_dataset.train_data = train_dataset.train_data[train_dataset.train_labels <= 1]
train_dataset.train_labels = train_dataset.train_labels[train_dataset.train_labels <= 1]
我们都知道常见的 MNIST 数据集,包含在 torchvision.datasets
包中。想象一下,我想创建这个数据集的简化版本,只包含 1 和 0 到 class 仅验证这两个数字而不是所有10 个值。
我看到自定义数据集可以在继承所需数据集的 class 中创建,因此 __getitem__
,其中 return 是给定索引处的项目。所以我这样做了:
class MNIST01(MNIST):
def __getitem__(self, idx):
image, label = super().__getitem__(idx)
if label.item() <= 1:
return image, label
else:
return None
问题是我似乎无法 return 一个 None 值,因为它需要 "contain tensors, numbers, dicts or lists; found class 'NoneType'"。
是否有一种简单的方法可以以类似的方式轻松获得此数据集的缩减版本?
我终于设法解决了 NoneType 问题。保留问题中定义的功能。
class MNIST01(MNIST):
def __getitem__(self, idx):
features, target = super(MNIST01, self).__getitem__(idx)
if target.item() <= 1:
return features, target
我们现在需要为我们的数据加载器定义一个自定义 collate function collate_fn
,它处理样本列表以形成一个批次。在这个函数中,我们可以应用过滤器来处理 None
值并忽略它们。
from torch.utils.data.dataloader import default_collate
def filter_collate(batch):
batch = list(filter(lambda x: x is not None, batch))
return default_collate(batch)
那么我们只需要将这个函数传递给DataLoader
:
from torch.utils.data import DataLoader
train_loader = DataLoader(train_dataset, collate_fn=filter_collate, **kwargs)
test_loader = DataLoader(test_dataset, collate_fn=filter_collate, **kwargs)
版本 2
比第一个简单多了,避免了一些访问数据的问题。只需从 MNIST
class.
train_data
和 train_label
属性(以及对应的测试集)
train_dataset.train_data = train_dataset.train_data[train_dataset.train_labels <= 1]
train_dataset.train_labels = train_dataset.train_labels[train_dataset.train_labels <= 1]