如何在pytorch中对数据集进行排序

How to sort a dataset in pytorch

我想按标签中的数值对我的数据集进行排序。

pytorch 是否有一个函数可以有效地处理这个问题?

我的数据集 type() 来自:

 <class 'torchvision.datasets.mnist.MNIST'>

没有通用的方法可以有效地做到这一点,因为数据集 class 只实现了 __getitem____len__ 方法,不一定有任何 "stored" 关于标签的信息。

MNIST dataset class 的情况下,您可以从标签列表中对数据集进行排序。

例如,当您想要列出标签为 5 的索引时。

mnist = torchvision.datasets.mnist.MNIST("/")
labels = mnist.train_labels
fives = (labels == 5).nonzero()