'DataLoader' 对象不支持索引
'DataLoader' object does not support indexing
我已经通过这个 pytorch 下载了 ImageNet 数据集 api 通过设置 download=True。但是我无法遍历数据加载器。
错误提示“'DataLoader' 对象不支持索引”
trainset = torch.utils.data.DataLoader(
datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train',
download=False))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False, num_workers=1)
我尝试了一种简单的方法,我只是尝试 运行 以下,
trainloader[0]
在根目录下,模式是
root/
train/
n01440764/
n01443537/
n01443537_2.jpg
官网文档没说别的。 https://pytorch.org/docs/stable/torchvision/datasets.html#imagenet
我做错了什么?
torch.utils.data.DataLoader()
的输入数据集应该是 torch.utils.data.Dataset
类型,而不是 torch.utils.data.DataLoader
类型,这正是您在上面的代码中所做的。
所以,你上面的代码应该是:
trainset = torchvision.datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/',
split='train',
download=False)
trainloader = torch.utils.data.DataLoader(trainset,
batch_size=1,
shuffle=False,
num_workers=1)
有关更多详细信息,请查看官方火炬文档here。
嗯,答案很简单(除了另一个答案中提到的错误)。
DataLoader
没有 __getitem__
方法(请自行查看 in the source code)。
它用于对数据(或成批数据)进行迭代,而不是随机访问。如果你想访问特定的元素,你应该使用 torch.utils.data.Dataset
,在你的情况下:
trainset = torchvision.datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train', )
trainset[0]
获取一批
如果你想得到一个批次,你可以迭代它然后中断:
for batch in dataloader:
print(batch) # or anything else you want to do
break
DataLoader
以默认或指定的方式创建随机索引(参见 samplers),因此没有 __getitem__
因为它对这个对象没有意义。
您也可以从 DataLoader
继承并创建自己的 __getitem__
函数来执行您想要的操作(虽然更复杂)。
完整示例
# torch.utils.data.Dataset object
trainset = datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train', download=True)
# torch.utils.data.DataLoader object
trainloader =torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False)
for batch in trainloader:
print(batch)
break
上面应该打印第一批里面的东西。
解决方案
input_transform = standard_transforms.Compose([
transforms.Resize((255,255)), # to Make sure all the
transforms.CenterCrop(224), # imgs are at the same size
transforms.ToTensor()
])
# torch.utils.data.Dataset object
trainset = datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/',
split='train', download=False, transform = input_transform)
# torch.utils.data.DataLoader object
trainloader =torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=False)
for batch_idx, data in enumerate(trainloader, 0):
x, y = data
break
我已经通过这个 pytorch 下载了 ImageNet 数据集 api 通过设置 download=True。但是我无法遍历数据加载器。
错误提示“'DataLoader' 对象不支持索引”
trainset = torch.utils.data.DataLoader(
datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train',
download=False))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False, num_workers=1)
我尝试了一种简单的方法,我只是尝试 运行 以下,
trainloader[0]
在根目录下,模式是
root/
train/
n01440764/
n01443537/
n01443537_2.jpg
官网文档没说别的。 https://pytorch.org/docs/stable/torchvision/datasets.html#imagenet
我做错了什么?
torch.utils.data.DataLoader()
的输入数据集应该是 torch.utils.data.Dataset
类型,而不是 torch.utils.data.DataLoader
类型,这正是您在上面的代码中所做的。
所以,你上面的代码应该是:
trainset = torchvision.datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/',
split='train',
download=False)
trainloader = torch.utils.data.DataLoader(trainset,
batch_size=1,
shuffle=False,
num_workers=1)
有关更多详细信息,请查看官方火炬文档here。
嗯,答案很简单(除了另一个答案中提到的错误)。
DataLoader
没有 __getitem__
方法(请自行查看 in the source code)。
它用于对数据(或成批数据)进行迭代,而不是随机访问。如果你想访问特定的元素,你应该使用 torch.utils.data.Dataset
,在你的情况下:
trainset = torchvision.datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train', )
trainset[0]
获取一批
如果你想得到一个批次,你可以迭代它然后中断:
for batch in dataloader:
print(batch) # or anything else you want to do
break
DataLoader
以默认或指定的方式创建随机索引(参见 samplers),因此没有 __getitem__
因为它对这个对象没有意义。
您也可以从 DataLoader
继承并创建自己的 __getitem__
函数来执行您想要的操作(虽然更复杂)。
完整示例
# torch.utils.data.Dataset object
trainset = datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train', download=True)
# torch.utils.data.DataLoader object
trainloader =torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False)
for batch in trainloader:
print(batch)
break
上面应该打印第一批里面的东西。
解决方案
input_transform = standard_transforms.Compose([
transforms.Resize((255,255)), # to Make sure all the
transforms.CenterCrop(224), # imgs are at the same size
transforms.ToTensor()
])
# torch.utils.data.Dataset object
trainset = datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/',
split='train', download=False, transform = input_transform)
# torch.utils.data.DataLoader object
trainloader =torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=False)
for batch_idx, data in enumerate(trainloader, 0):
x, y = data
break