PyTorch random_split() 返回错误大小的加载器

PyTorch random_split() is returning wrong sized loader

我为自己的数据集编写了一个自定义的数据集类。我想将数据集拆分为 70% 的训练数据、20% 的验证数据和 10% 的测试数据。我一共有 16,488 条数据,所以训练数据应该约为 11,542 条。但实际得到的却是 770 条训练数据、220 条验证数据和 110 条测试数据。我已经尝试排查,但找不出问题所在。

class Dataset(Dataset):
    """Image dataset whose age and gender labels are parsed from file names.

    Every file in ``directory`` whose name matches the pattern
    ``{}_{}_{age}_{gender}.jpg`` becomes one sample; files that do not match
    are skipped. ``gender`` must be ``'m'`` or ``'f'`` and ``age`` must be
    convertible with ``int``; otherwise construction raises ``KeyError`` or
    ``ValueError`` respectively.

    Args:
        directory: Folder that is scanned for image files.
        transform: Callable applied to each RGB PIL image, or ``None`` to
            keep the PIL image unchanged.
        preload: When true, every image is opened (and transformed) during
            construction; otherwise only the file path is stored and the
            image is opened on each ``__getitem__`` call.
        device: Device that tensor images are moved to.
        **kwargs: Accepted and ignored.
    """

    # Maps the gender token found in the file name to its integer class id.
    _GENDER_TO_CLASS_ID = {
        'm': 0,
        'f': 1
    }

    def __init__(self, directory, transform, preload=False, device: torch.device = torch.device('cpu'), **kwargs):
        self.device = device
        self.directory = directory
        self.transform = transform
        self.labels = []
        self.images = []
        self.preload = preload

        # Sorted so that sample indices do not depend on the arbitrary,
        # platform-specific order returned by os.listdir.
        for file in sorted(os.listdir(self.directory)):
            file_labels = parse('{}_{}_{age}_{gender}.jpg', file)

            if file_labels is None:
                continue

            path = os.path.join(self.directory, file)
            if self.preload:
                image = Image.open(path).convert('RGB')
                if self.transform is not None:
                    image = self._to_device(self.transform(image))
            else:
                # Lazy mode: keep only the path, open it in __getitem__.
                image = path

            self.images.append(image)
            self.labels.append({
                'age': int(file_labels['age']),
                'gender': self._GENDER_TO_CLASS_ID[file_labels['gender']]
            })

    def _to_device(self, image):
        """Move ``image`` to ``self.device`` if it is a tensor; else return it as is."""
        # Without a transform the image is still a PIL image, which has no
        # ``.to`` method, so only tensors are moved.
        if torch.is_tensor(image):
            return image.to(self.device)
        return image

    def __len__(self):
        """Return the number of samples (files that matched the pattern)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for the sample at ``idx``.

        ``labels`` is a new dict with the keys ``'age'`` and ``'gender'``.
        The image is moved to ``self.device`` when it is a tensor; when no
        transform was given it is returned as the RGB PIL image.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image = self.images[idx]

        if not self.preload:
            # In lazy mode ``image`` is a file path at this point.
            image = Image.open(image).convert('RGB')
            if self.transform is not None:
                image = self.transform(image)

        sample_labels = self.labels[idx]
        labels = {
            'age': sample_labels['age'],
            'gender': sample_labels['gender'],
        }
        return self._to_device(image), labels

    def get_loaders(self, transform, train_size=0.7, validate_size=0.2, test_size=0.1, batch_size=15, **kwargs):
        """Randomly split the dataset and wrap each part in a ``DataLoader``.

        The three subsets are also stored on the instance as
        ``trainDataset``, ``validateDataset`` and ``testDataset``. Split
        lengths are truncated to whole samples; any remaining samples are
        placed in a fourth split that is discarded.

        Note that ``len(loader)`` is the number of *batches*, i.e.
        ``ceil(samples / batch_size)``. Use ``len(loader.dataset)`` to get
        the number of samples in a split.

        Args:
            transform: Unused; kept so existing callers keep working.
            train_size: Fraction of samples for the training split.
            validate_size: Fraction of samples for the validation split.
            test_size: Fraction of samples for the test split.
            batch_size: Batch size used by all three loaders.
            **kwargs: Accepted and ignored.

        Returns:
            Tuple ``(train_loader, validate_loader, test_loader)``.

        Raises:
            SystemExit: If the fractions add up to more than 1.
        """
        total = train_size + validate_size + test_size

        train_len = int(len(self) * train_size)
        validate_len = int(len(self) * validate_size)
        test_len = int(len(self) * test_size)
        others_len = len(self) - train_len - validate_len - test_len

        # A small tolerance absorbs float error (0.7 + 0.2 + 0.1 is not
        # exactly 1.0). A negative remainder must never reach random_split:
        # it would silently produce truncated subsets.
        if total > 1.0 + 1e-9 or others_len < 0:
            sys.exit("Sum of the percentages should be less than 1. it's " + str(
                total) + " now!")

        self.trainDataset, self.validateDataset, self.testDataset, _ = torch.utils.data.random_split(
            self, [train_len, validate_len, test_len, others_len]
        )

        train_loader = DataLoader(self.trainDataset, batch_size=batch_size)
        validate_loader = DataLoader(self.validateDataset, batch_size=batch_size)
        test_loader = DataLoader(self.testDataset, batch_size=batch_size)

        return train_loader, validate_loader, test_loader

看起来你设置了:

batch_size=15

由于数据加载器是按批次(batch)迭代的,对它调用 len() 得到的很可能是批次的数量,而不是样本的数量。这也解释了为什么你得到的训练数据是 770,而不是预期的 11,542。因为:

16488 / 15 * 0.7 = 769.44 ≈ 770

将 batch_size 设置为 1 应该可以解决问题(或者直接用 len(loader.dataset) 查看样本数量):

16488 / 1 * 0.7 = 11541.6 ≈ 11542