Number of instances per class in pytorch dataset
I'm trying to make a simple image classifier with PyTorch.
This is how I load the data into a dataset and a DataLoader:
import torch
from torch.utils.data import SubsetRandomSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder

batch_size = 64
validation_split = 0.2
data_dir = PROJECT_PATH + "/categorized_products"  # PROJECT_PATH is defined elsewhere in the project
transform = transforms.Compose([transforms.Grayscale(), CustomToTensor()])  # CustomToTensor is a custom transform
dataset = ImageFolder(data_dir, transform=transform)
indices = list(range(len(dataset)))
split = int(len(indices) * (1 - validation_split))  # 80% train, 20% test
train_indices = indices[:split]
test_indices = indices[split:]
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler, num_workers=16)
test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=test_sampler, num_workers=16)
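ImageFolder collects its samples folder by folder, in sorted class order, so all samples of a class sit next to each other. Slicing an unshuffled index list can therefore put whole classes only in the test split; SubsetRandomSampler shuffles the order within each list but does not change which indices are in it. A minimal pure-Python sketch of a seeded, shuffled split, where split_indices is a hypothetical helper and the toy targets list stands in for dataset.targets:

```python
import random


def split_indices(n, validation_split=0.2, seed=0):
    """Shuffle 0..n-1 and split into train/test index lists (hypothetical helper)."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded shuffle so the split is reproducible
    split = int(n * (1 - validation_split))
    return indices[:split], indices[split:]


# Toy labels grouped by class, like ImageFolder's ordering: 0 = shirts, 1 = shoes
targets = [0] * 24 + [1] * 17
train_indices, test_indices = split_indices(len(targets))
print(len(train_indices), len(test_indices))  # 32 9
```

With these toy labels, the unshuffled slice indices[32:] would contain only class 1. The shuffled lists can be passed to SubsetRandomSampler exactly as before.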
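I want to print the number of images in each class, separately for the training and test data, like this:
In train data:
- shoes: 20
- shirts: 14
In test data:
- shoes: 4
- shirts: 3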
I tried this:
from collections import Counter
print(dict(Counter(sample_tup[1] for sample_tup in dataset.imgs)))
But I got this error:
AttributeError: 'MyDataset' object has no attribute 'img'
You need to use .targets to access the data labels, i.e.:
print(dict(Counter(dataset.targets)))  # if .targets is a tensor (e.g. torchvision's MNIST), use dataset.targets.tolist()
It will print something like this (e.g. for the MNIST dataset):
{5: 5421, 0: 5923, 4: 5842, 1: 6742, 9: 5949, 2: 5958, 3: 6131, 6: 5918, 7: 6265, 8: 5851}
Also, you can use .classes or .class_to_idx to get the mapping between label IDs and class names:
print(dataset.class_to_idx)
{'0 - zero': 0,
'1 - one': 1,
'2 - two': 2,
'3 - three': 3,
'4 - four': 4,
'5 - five': 5,
'6 - six': 6,
'7 - seven': 7,
'8 - eight': 8,
'9 - nine': 9}
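Since class_to_idx maps class names to label IDs, inverting it lets you relabel the counts from .targets by name. A minimal pure-Python sketch; the toy class_to_idx and targets values are assumptions standing in for the real dataset attributes:

```python
from collections import Counter

# Toy stand-ins for dataset.class_to_idx and dataset.targets
class_to_idx = {"shirts": 0, "shoes": 1}
targets = [1, 0, 1, 1, 0]

# Invert name -> ID into ID -> name, then relabel the label counts
idx_to_class = {idx: name for name, idx in class_to_idx.items()}
counts_by_name = {idx_to_class[idx]: n for idx, n in Counter(targets).items()}
print(counts_by_name)  # {'shoes': 3, 'shirts': 2}
```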
Edit: Method 1
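Following the comments: to get the class distribution of the training and test sets separately, you can simply iterate over the subsets, as follows: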
from collections import Counter
import torch

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
# labels in training set
train_classes = [label for _, label in train_dataset]
Counter(train_classes)
Counter({0: 4757,
1: 5363,
2: 4782,
3: 4874,
4: 4678,
5: 4321,
6: 4747,
7: 5024,
8: 4684,
9: 4770})
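This works because a Subset returns the (sample, label) pairs of the underlying dataset, which also means every image is read and passed through the transform just to obtain its label. The toy class below is hypothetical and pure Python, standing in for ImageFolder; it counts how often a sample is "loaded" to make that cost visible, and contrasts it with reading labels from a stored list:

```python
class ToyImageDataset:
    """Hypothetical stand-in for ImageFolder: labels are stored, images are 'loaded' on access."""

    def __init__(self, targets):
        self.targets = list(targets)
        self.loads = 0  # number of times __getitem__ had to load a sample

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, i):
        self.loads += 1  # in ImageFolder, this is where the file is opened and transformed
        return f"image_{i}", self.targets[i]


dataset = ToyImageDataset([0, 1, 1, 0, 1])
subset_indices = [0, 2, 4]

# Iterating samples and keeping only the label loads every sample
labels_by_iteration = [dataset[i][1] for i in subset_indices]
print(labels_by_iteration, dataset.loads)  # [0, 1, 1] 3

# Reading the stored labels directly loads nothing
labels_by_lookup = [dataset.targets[i] for i in subset_indices]
print(labels_by_lookup, dataset.loads)  # [0, 1, 1] 3
```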
Edit (2): Method 2
Since you have a large dataset and, as you said, iterating over the whole training set takes quite a long time, there is another way:
You can use the .indices attribute of the subset, which holds the indices of the original dataset that were selected for the subset, i.e.:
train_classes = [dataset.targets[i] for i in train_dataset.indices]
Counter(train_classes)  # if .targets is a tensor, use: Counter(i.item() for i in train_classes)
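Applied to the question's own setup, the same idea works directly on train_indices and test_indices, since the samplers only draw from those lists. A minimal sketch that prints the format asked for; class_counts is a hypothetical helper, and the toy targets, classes, and index lists are assumptions standing in for dataset.targets, dataset.classes, and the question's splits:

```python
from collections import Counter


def class_counts(targets, classes, indices):
    """Count samples of each class among `indices` (hypothetical helper)."""
    counts = Counter(int(targets[i]) for i in indices)  # int() also accepts 0-d tensors
    return {name: counts[idx] for idx, name in enumerate(classes)}


# Toy stand-ins for dataset.targets, dataset.classes and the question's index splits
targets = [0, 0, 0, 1, 1, 0, 1]
classes = ["shoes", "shirts"]
train_indices = [0, 1, 3, 4, 5]
test_indices = [2, 6]

for title, indices in [("In train data", train_indices), ("In test data", test_indices)]:
    print(f"{title}:")
    for name, n in class_counts(targets, classes, indices).items():
        print(f"- {name}: {n}")
```

Classes that do not occur in a split are still listed with a count of 0, because a Counter returns 0 for missing keys.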
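Simple and easy:
If you have a dataset class (in your case ImageFolder), you can count the labels directly:
import torch

dataset = MyDataset()  # which in your case is ImageFolder
num_classes = len(dataset.classes)  # ImageFolder provides .classes
labels = torch.zeros(num_classes, dtype=torch.long)
for _, target in dataset:
    labels[target] += 1  # integer targets; with one-hot targets use labels += target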
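Accumulating with labels += target yields per-class counts only when each target is a one-hot vector; with integer targets such as ImageFolder's, you increment the slot for that label instead. A pure-Python sketch on toy data (no torch) showing that both approaches give the same counts:

```python
def counts_by_index(targets, num_classes):
    """Integer targets: increment the slot for each label."""
    counts = [0] * num_classes
    for t in targets:
        counts[t] += 1
    return counts


def counts_by_one_hot(targets, num_classes):
    """One-hot targets: add each vector elementwise (what labels += target does)."""
    counts = [0] * num_classes
    for t in targets:
        one_hot = [1 if c == t else 0 for c in range(num_classes)]
        counts = [a + b for a, b in zip(counts, one_hot)]
    return counts


targets = [2, 0, 2, 1, 2]  # toy integer labels
print(counts_by_index(targets, 3))  # [1, 1, 3]
print(counts_by_one_hot(targets, 3))  # [1, 1, 3]
```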