torchvision.datasets.cifar.CIFAR10 是不是一个列表?
Is torchvision.datasets.cifar.CIFAR10 a list or not?
当我试图找出里面的东西时 torchvision.datasets.cifar.CIFAR10,我做了一些简单的代码
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
print(trainset[1])
print(trainset[:10])
print(type(trainset))
但是,当我尝试
时出现了一些错误
print(trainset[:10])
错误信息是
TypeError: Cannot handle this data type
我想知道为什么我可以使用 trainset[1]
,但不能使用 trainset[:10]
?
CIFAR10 不支持切片,这就是您收到该错误的原因。如果你想要前 10 个,你将不得不这样做:
print([trainset[i] for i in range(10)])
更多信息
你可以索引 CIFAR10 实例的主要原因 class 是因为 class 实现了 __getitem__()
功能。
因此,当您调用 trainset[i]
时,您实际上是在调用 trainset.__getitem__(i)
现在,在 python3 中,切片表达式也通过 __getitem__()
处理,其中切片表达式作为切片对象传递给 __getitem__()
。
所以,trainset[2:10]
等同于 trainset.__getitem__(slice(2, 10))
并且由于将两种不同类型的对象传递给 __getitem__
期望做完全不同的事情,你必须明确地处理它们。
不幸的是它不是,正如你从 CIFAR10 的 __getitem__
方法实现中看到的 class:
def __getitem__(self, index):
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
除了 entrophys 答案之外,我建议使用 torch.utils.data.dataset.random_split 例如这:方式:
train_size = int(0.8*len(dataset))
test_size = len(dataset) - train_size
lengths = [train_size, test_size]
train_dataset, valid_dataset = torch.utils.data.dataset.random_split(dataset, lengths)
trainloader = DataLoader(train_data,
batch_size=args.train_batch,
shuffle=True,
num_workers=args.nThreads,
pin_memory=True)
validloader = DataLoader(valid_data,
batch_size=args.train_batch,
shuffle=True,
num_workers=args.nThreads,
pin_memory=True)
当我试图找出里面的东西时 torchvision.datasets.cifar.CIFAR10,我做了一些简单的代码
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
print(trainset[1])
print(trainset[:10])
print(type(trainset))
但是,当我尝试
时出现了一些错误print(trainset[:10])
错误信息是
TypeError: Cannot handle this data type
我想知道为什么我可以使用 trainset[1]
,但不能使用 trainset[:10]
?
CIFAR10 不支持切片,这就是您收到该错误的原因。如果你想要前 10 个,你将不得不这样做:
print([trainset[i] for i in range(10)])
更多信息
你可以索引 CIFAR10 实例的主要原因 class 是因为 class 实现了 __getitem__()
功能。
因此,当您调用 trainset[i]
时,您实际上是在调用 trainset.__getitem__(i)
现在,在 python3 中,切片表达式也通过 __getitem__()
处理,其中切片表达式作为切片对象传递给 __getitem__()
。
所以,trainset[2:10]
等同于 trainset.__getitem__(slice(2, 10))
并且由于将两种不同类型的对象传递给 __getitem__
期望做完全不同的事情,你必须明确地处理它们。
不幸的是它不是,正如你从 CIFAR10 的 __getitem__
方法实现中看到的 class:
def __getitem__(self, index):
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
除了
train_size = int(0.8*len(dataset))
test_size = len(dataset) - train_size
lengths = [train_size, test_size]
train_dataset, valid_dataset = torch.utils.data.dataset.random_split(dataset, lengths)
trainloader = DataLoader(train_data,
batch_size=args.train_batch,
shuffle=True,
num_workers=args.nThreads,
pin_memory=True)
validloader = DataLoader(valid_data,
batch_size=args.train_batch,
shuffle=True,
num_workers=args.nThreads,
pin_memory=True)