Pytorch DataLoader - Choose Class STL10 Dataset
Is it possible to pull only the samples where class = 0 from the STL10 dataset in PyTorch torchvision? I can check them in a loop, but I need to receive batches of class 0 images.
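```python
import torch
import torchvision
import torchvision.transforms as transforms

batch_size = 25  # value taken from the edit below

# STL10 dataset
train_dataset = torchvision.datasets.STL10(root='./data/',
                                           transform=transforms.Compose([
                                               transforms.Grayscale(),
                                               transforms.ToTensor()
                                           ]),
                                           split='train',
                                           download=True)

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

for i, (images, labels) in enumerate(train_loader):
    # Only tests the first sample of each shuffled, mixed-class batch
    if labels[0] == 0:
        ...
```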
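Checking `labels[0]` only inspects the first sample of each shuffled batch; the other samples in that batch can belong to any class. Masking each batch does keep only class 0, but as this sketch shows, the filtered batches come out with varying sizes and every other image is still loaded and transformed. That is why selecting the indices up front is preferable. The sketch assumes a hypothetical synthetic `TensorDataset` in place of STL10 so it runs without a download.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)  # fixed seed so the shuffle is reproducible

# Hypothetical stand-in for STL10: 100 random 1x8x8 "images", labels cycle 0..9
images = torch.randn(100, 1, 8, 8)
labels = torch.arange(100) % 10
loader = DataLoader(TensorDataset(images, labels), batch_size=25, shuffle=True)

class0_batch_sizes = []
for batch_images, batch_labels in loader:
    # Boolean mask selects only the class-0 rows of this mixed-class batch
    class0_images = batch_images[batch_labels == 0]
    class0_batch_sizes.append(class0_images.shape[0])

# The sizes depend on the shuffle; they always add up to the 10 class-0 samples
print(class0_batch_sizes, sum(class0_batch_sizes))
```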
EDIT based on iacolippo's answer - now works:
```python
import torch
import torchvision
import torchvision.transforms as transforms

# Set params
batch_size = 25
label_class = 0  # only airplane images

# Return only images of certain class (e.g. airplanes = class 0)
def get_same_index(target, label):
    label_indices = []
    for i in range(len(target)):
        if target[i] == label:
            label_indices.append(i)
    return label_indices

# STL10 dataset
train_dataset = torchvision.datasets.STL10(root='./data/',
                                           transform=transforms.Compose([
                                               transforms.Grayscale(),
                                               transforms.ToTensor()
                                           ]),
                                           split='train',
                                           download=True)

# Get indices of label_class
train_indices = get_same_index(train_dataset.labels, label_class)

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           sampler=torch.utils.data.sampler.SubsetRandomSampler(train_indices))
```
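If the goal is simply a smaller, single-class dataset, `torch.utils.data.Subset` is an alternative to a sampler: the loader then sees only the selected samples, `len(loader.dataset)` matches the subset size, and the loader's own `shuffle=True` works as usual. This is one possible variant, not the asker's code. It assumes a hypothetical synthetic `TensorDataset` in place of STL10 so it runs offline; with the real data you would pass `train_dataset` and `train_indices` instead.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Hypothetical stand-in for STL10: 100 random 1x8x8 "images", labels cycle 0..9
images = torch.randn(100, 1, 8, 8)
labels = torch.arange(100) % 10
dataset = TensorDataset(images, labels)

label_class = 0  # airplane in STL10
indices = [i for i, t in enumerate(labels.tolist()) if t == label_class]

# Subset exposes only the chosen indices, so the loader's shuffle works on them
class0_dataset = Subset(dataset, indices)
loader = DataLoader(class0_dataset, batch_size=4, shuffle=True)

num_samples = 0
for batch_images, batch_labels in loader:
    # every batch contains only label_class samples
    assert (batch_labels == label_class).all()
    num_samples += batch_labels.numel()

print(len(class0_dataset), num_samples)  # 10 10
```

With `SubsetRandomSampler`, by contrast, `len(train_loader.dataset)` still reports the full training set, which matters if any code derives epoch sizes from it.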
If you only want samples from one class, you can get the indices of the samples with that class from the Dataset instance, e.g.
```python
def get_same_index(target, label):
    label_indices = []
    for i in range(len(target)):
        if target[i] == label:
            label_indices.append(i)
    return label_indices
```
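The same index list can be built more compactly. The list comprehension below behaves like `get_same_index` for lists and NumPy arrays; the NumPy variant vectorizes the comparison, which may help because torchvision's STL10 stores `labels` as a NumPy array (true of the versions I know; check yours). Both function names are hypothetical, chosen only for this sketch.

```python
import numpy as np


def get_same_index_compact(target, label):
    """Return the positions i where target[i] == label."""
    return [i for i, t in enumerate(target) if t == label]


def get_same_index_numpy(target, label):
    """Vectorized variant: compare all labels at once, return indices as a list."""
    return np.flatnonzero(np.asarray(target) == label).tolist()


# Small hand-made label list standing in for train_dataset.labels
example_labels = [0, 3, 0, 9, 1, 0]
print(get_same_index_compact(example_labels, 0))  # [0, 2, 5]
print(get_same_index_numpy(example_labels, 0))    # [0, 2, 5]
```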
Then you can use SubsetRandomSampler to draw samples only from the list of indices of one class:
```python
torch.utils.data.sampler.SubsetRandomSampler(indices)
```
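A minimal end-to-end sketch of that approach, assuming a hypothetical synthetic `TensorDataset` in place of STL10 so it runs without a download. The sampler draws only the listed indices, in a new random order on each pass over the loader, so every batch contains a single class.

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Hypothetical stand-in for STL10: 100 random 1x8x8 "images", labels cycle 0..9
images = torch.randn(100, 1, 8, 8)
labels = torch.arange(100) % 10
dataset = TensorDataset(images, labels)

label_class = 0
indices = [i for i, t in enumerate(labels.tolist()) if t == label_class]

# The sampler yields only the listed indices, reshuffled on every pass.
# sampler and shuffle=True cannot be combined in DataLoader.
loader = DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler(indices))

seen_labels = []
for batch_images, batch_labels in loader:
    seen_labels.extend(batch_labels.tolist())

print(len(seen_labels), set(seen_labels))  # 10 {0}
```

Because `DataLoader` raises an error when both `sampler` and `shuffle=True` are given, the shuffling here comes entirely from the sampler.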