使用 Pytorch 显示每个 class 的图像数量

Display number of images per class using Pytorch

我将 Pytorch 与 FashionMNIST 数据集一起使用我想显示 10 个 classes 中的每一个的 8 个图像样本。但是,我没有想出如何将训练测试拆分为 train_labels,因为我需要在标签 (class) 上循环并打印每个 class 中的 8 个。 知道如何实现吗?

classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot')

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                              #  transforms.Lambda(lambda x: x.repeat(3,1,1)),
                                transforms.Normalize((0.5, ), (0.5,))])
# Download and load the training data
trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)
# Download and load the test data
testset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True)


print('Training set size:', len(trainset))
print('Test set size:',len(testset))

如果我对你的理解正确,你想按标签对数据集进行分组然后显示它们。

您可以先构建一个字典来按标签存储示例:

examples = {i: [] for i in range(len(classes))}

然后遍历训练集并使用标签的索引追加到列表中:

for x, i in trainset:
    examples[i].append(x)

但是,这将涵盖整个系列。如果你想提前停止并避免收集超过 8 per-class 你可以通过添加条件来做到这一点:

n_examples = 8
for x, i in trainset:
    if all([len(ex) == n_examples for ex in examples.values()])
        break
    if len(examples[i]) < n_examples:
        examples[i].append(x)

只剩下显示 torchvision.transforms.ToPILImage:

transforms.ToPILImage()(examples[3][0])

如果要显示多个,可以使用两个连续的 torch.cat,一个在 dim=1(按行)然后在 dim=2(按列)创建一个网格.

grid = torch.cat([torch.cat(examples[i], dim=1) for i in range(len(classes))], dim=2)
transforms.ToPILImage()(grid)

可能的结果: