How to load one type of image in CIFAR10 or STL10 with PyTorch
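This is a very simple question: I just want to select the images of one specific class (e.g. "car") from a standard PyTorch image dataset. Currently the data loader looks like this: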
import torch
import torchvision

def cycle(iterable):
    # Loop over the loader forever; each pass re-iterates the DataLoader,
    # so the data is reshuffled every epoch
    while True:
        for x in iterable:
            yield x

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.STL10('drive/My Drive/training/stl10', split='train+unlabeled',
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                               ])),
    shuffle=True, batch_size=8)

class_names = ['airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck']
train_iterator = iter(cycle(train_loader))
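The iterator returns a batch of shuffled images from every class, but I want to be able to choose which class is returned, e.g. only deer images or only ship images.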
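In STL10 an image's integer label is the position of its class in the class_names list above, so the classes to keep can first be turned into label values. A minimal sketch; the helper name labels_for is hypothetical (not part of torchvision), and for CIFAR10, whose class order differs, the list should come from that dataset's own classes attribute instead:

```python
# STL10 class order, copied from the question; a label value is the index
# of its class name in this list (CIFAR10 uses a different order).
STL10_CLASSES = ['airplane', 'bird', 'car', 'cat', 'deer', 'dog',
                 'horse', 'monkey', 'ship', 'truck']


def labels_for(names, class_names=STL10_CLASSES):
    """Hypothetical helper: map class names such as 'deer' to label values."""
    unknown = [n for n in names if n not in class_names]
    if unknown:
        raise ValueError(f"unknown class names: {unknown}")
    return [class_names.index(n) for n in names]


if __name__ == "__main__":
    print(labels_for(['deer', 'ship']))  # [4, 8]
```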
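Done! Collect the indices of the samples whose label matches the wanted class and wrap the dataset in torch.utils.data.Subset: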
import torch
import torchvision

def cycle(iterable):
    # Loop over the loader forever; each pass re-iterates the DataLoader,
    # so the data is reshuffled every epoch
    while True:
        for x in iterable:
            yield x

# Return the indices of all samples of a certain class (e.g. airplane = class 0)
def get_same_index(target, label):
    label_indices = []
    for i in range(len(target)):
        if target[i] == label:
            label_indices.append(i)
    return label_indices

# STL10 dataset
train_dataset = torchvision.datasets.STL10('drive/My Drive/training/stl10', split='train+unlabeled', download=True,
                                           transform=torchvision.transforms.Compose([
                                               torchvision.transforms.ToTensor()]))

batch_size = 8  # same batch size as in the question
label_class = 1  # birds
# Get indices of label_class (unlabeled images have label -1, so they are skipped)
train_indices = get_same_index(train_dataset.labels, label_class)
bird_set = torch.utils.data.Subset(train_dataset, train_indices)
train_loader = torch.utils.data.DataLoader(dataset=bird_set, shuffle=True,
                                           batch_size=batch_size, drop_last=True)
train_iterator = iter(cycle(train_loader))
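The same approach extends to several classes at once (e.g. deer or ship) and to CIFAR10, where torchvision keeps the labels as a Python list in dataset.targets rather than a NumPy array in dataset.labels. A minimal sketch, assuming you pass in one of those label sequences; indices_for_labels is a hypothetical name, and storing the wanted labels in a set keeps each membership test O(1):

```python
def indices_for_labels(targets, wanted):
    """Hypothetical helper: indices of samples whose label is in `wanted`.

    `targets` may be CIFAR10's `dataset.targets` (a list of ints) or STL10's
    `dataset.labels` (a NumPy array); unlabeled STL10 images (label -1) are
    never selected unless -1 is in `wanted`.
    """
    wanted = set(int(w) for w in wanted)
    return [i for i, t in enumerate(targets) if int(t) in wanted]


if __name__ == "__main__":
    fake_targets = [4, 8, -1, 1, 8, 0]  # made-up labels for illustration
    print(indices_for_labels(fake_targets, [4, 8]))  # [0, 1, 4]
```

The returned list plugs into torch.utils.data.Subset exactly like train_indices above, e.g. torch.utils.data.Subset(cifar_train, indices_for_labels(cifar_train.targets, [4, 8])), assuming cifar_train is a torchvision.datasets.CIFAR10 instance; in both CIFAR10 and STL10, label 4 is deer and label 8 is ship.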