How to load one type of image in CIFAR10 or STL10 with PyTorch

This is a very simple question: I just want to select images of one specific class (e.g. "car") from a standard PyTorch image dataset. Currently the data loader looks like this:

import torch
import torchvision

def cycle(iterable):
    # Loop over the DataLoader forever so next() never raises StopIteration
    while True:
        for x in iterable:
            yield x

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.STL10('drive/My Drive/training/stl10', split='train+unlabeled',
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                               ])),
    shuffle=True, batch_size=8)
class_names = ['airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck']

train_iterator = iter(cycle(train_loader))

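The iterator returns a batch of shuffled images of all classes, but I want to be able to choose which class of images is returned, e.g. only deer or only ship images.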

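Done! The trick is to collect the indices of the samples that have the wanted label and wrap the dataset in a torch.utils.data.Subset: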

import torch
import torchvision

def cycle(iterable):
    # Loop over the DataLoader forever
    while True:
        for x in iterable:
            yield x

# Return the indices of the samples whose label equals `label` (e.g. airplane = class 0)
def get_same_index(target, label):
    label_indices = []
    for i in range(len(target)):
        if target[i] == label:
            label_indices.append(i)
    return label_indices

# STL10 dataset
train_dataset = torchvision.datasets.STL10('drive/My Drive/training/stl10', split='train+unlabeled', download=True,
                                           transform=torchvision.transforms.Compose([
                                               torchvision.transforms.ToTensor()]))

label_class = 1  # birds
batch_size = 8   # same batch size as in the question

# Get indices of label_class (unlabeled STL10 images carry label -1, so they are skipped)
train_indices = get_same_index(train_dataset.labels, label_class)

# A view of the dataset that contains only the bird images
bird_set = torch.utils.data.Subset(train_dataset, train_indices)

train_loader = torch.utils.data.DataLoader(dataset=bird_set, shuffle=True,
                                           batch_size=batch_size, drop_last=True)
train_iterator = iter(cycle(train_loader))
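The question also asks for more than one class at a time (e.g. only deer or ship), and the title mentions CIFAR10, whose torchvision dataset keeps its labels in a .targets list rather than STL10's .labels array. Below is a minimal sketch that generalizes get_same_index to a set of class ids. The helper indices_for_classes and the fake TensorDataset are hypothetical stand-ins (assumptions, not part of the original answer) so the example runs without downloading anything.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# STL10 class order, as listed in the question.
CLASS_NAMES = ['airplane', 'bird', 'car', 'cat', 'deer',
               'dog', 'horse', 'monkey', 'ship', 'truck']


def indices_for_classes(labels, wanted):
    """Hypothetical helper: indices i such that labels[i] is in `wanted`.

    `labels` may be a Python list (CIFAR10.targets), a NumPy array
    (STL10.labels) or a 1-D tensor. Unlabeled STL10 images have
    label -1, so they are never selected.
    """
    wanted = {int(c) for c in wanted}
    return [i for i, y in enumerate(labels) if int(y) in wanted]


if __name__ == '__main__':
    # Fake stand-in for a real dataset: 20 random 3x8x8 "images"
    # with labels 0..9 repeated twice.
    images = torch.rand(20, 3, 8, 8)
    labels = torch.arange(20) % 10
    fake_dataset = TensorDataset(images, labels)

    # Look up the class ids by name: deer -> 4, ship -> 8
    wanted = [CLASS_NAMES.index('deer'), CLASS_NAMES.index('ship')]
    subset = Subset(fake_dataset, indices_for_classes(labels, wanted))

    loader = DataLoader(subset, batch_size=2, shuffle=True, drop_last=True)
    batch_images, batch_labels = next(iter(loader))
    # Every label in the batch is either 4 (deer) or 8 (ship)
    print(batch_images.shape, batch_labels.tolist())
```

With a real dataset you would pass train_dataset.labels (STL10) or train_dataset.targets (CIFAR10) instead of the fake labels. Note that CIFAR10's class order differs from the STL10 list above (it has 'automobile' and 'frog' instead of 'car' and 'monkey'), so for CIFAR10 look the ids up in train_dataset.classes rather than reusing CLASS_NAMES.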