PyTorch:自定义批量采样器在第一个纪元后耗尽

PyTorch: Custom batch sampler exhausts after first epoch

我使用带有自定义 batch_samplerDataLoader 来确保每个批次 class 平衡。如何防止迭代器在第一个纪元耗尽自身?

import torch

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.x = torch.rand(10, 10)
        self.y = torch.Tensor([0] * 5 + [1] * 5)
        
    def __len__(self):
        len(self.y)
        
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

def custom_batch_sampler():
    batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
    return iter(batch_idx)

def train(loader):
    for epoch in range(10):
        for batch, (x, y) in enumerate(loader):
            print('epoch:', epoch, 'batch:', batch) # stops after first epoch

if __name__=='__main__':
    my_dataset = CustomDataset()
    my_loader = torch.utils.data.DataLoader(
        dataset=my_dataset,
        batch_sampler=custom_batch_sampler()
    )
    train(my_loader)

训练在第一个纪元后停止,next(iter(loader)) 给出 StopIteration 错误。

epoch: 0 batch: 0
epoch: 0 batch: 1
epoch: 0 batch: 2
epoch: 0 batch: 3
epoch: 0 batch: 4

自定义批量采样器需要是 Sampler 或一些可迭代的。在每个时期,都会从这个可迭代对象中生成一个新的迭代器。这意味着您实际上不需要手动制作迭代器(它将 运行 出来并在第一个纪元后提高 StopIteration ),但您可以只提供您的列表,因此如果您删除它应该可以工作iter():

def custom_batch_sampler():
    batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
    return batch_idx