torchvision.datasets.cifar.CIFAR10 是不是一个列表?

Is torchvision.datasets.cifar.CIFAR10 a list or not?

当我试图找出里面的东西时 torchvision.datasets.cifar.CIFAR10,我做了一些简单的代码

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                    download=True, transform=transform)
print(trainset[1])
print(trainset[:10])
print(type(trainset))

但是,当我尝试

时出现了一些错误
print(trainset[:10])

错误信息是

TypeError: Cannot handle this data type

我想知道为什么我可以使用 trainset[1],但不能使用 trainset[:10]

CIFAR10 不支持切片,这就是您收到该错误的原因。如果你想要前 10 个,你将不得不这样做:

print([trainset[i] for i in range(10)])

更多信息

你可以索引 CIFAR10 实例的主要原因 class 是因为 class 实现了 __getitem__() 功能。

因此,当您调用 trainset[i] 时,您实际上是在调用 trainset.__getitem__(i)

现在,在 python3 中,切片表达式也通过 __getitem__() 处理,其中切片表达式作为切片对象传递给 __getitem__()

所以,trainset[2:10] 等同于 trainset.__getitem__(slice(2, 10))

并且由于将两种不同类型的对象传递给 __getitem__ 期望做完全不同的事情,你必须明确地处理它们。

不幸的是它不是,正如你从 CIFAR10 的 __getitem__ 方法实现中看到的 class:

def __getitem__(self, index):
    if self.train:
        img, target = self.train_data[index], self.train_labels[index]
    else:
        img, target = self.test_data[index], self.test_labels[index]

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img)

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target

除了 entrophys 答案之外,我建议使用 torch.utils.data.dataset.random_split 例如这:方式:

train_size = int(0.8*len(dataset))
test_size = len(dataset) - train_size
lengths = [train_size, test_size]
train_dataset, valid_dataset = torch.utils.data.dataset.random_split(dataset, lengths)
trainloader = DataLoader(train_data, 
  batch_size=args.train_batch,  
  shuffle=True, 
  num_workers=args.nThreads, 
  pin_memory=True)
validloader = DataLoader(valid_data, 
  batch_size=args.train_batch,  
  shuffle=True, 
  num_workers=args.nThreads, 
  pin_memory=True)

来源:https://yimjiyoung.github.io/2020/02/13/How-to-split-dataset-into-train-and-validation-set-in-pytorch/