PyTorch - 丢弃数据加载器批处理

PyTorch - discard dataloader batch

我有一个自定义 Dataset 从大文件加载数据。有时,加载的数据是空的,我不想用它们来训练。

Dataset我有:

   def __getitem__(self, i):
       (x, y) = self.getData(i) #getData loads data and handles problems     
       return (x, y)

如果数据不正确 return (None, None)xy 都是 None)。但是,它后来在 DataLoader 中失败了,我无法完全跳过这批。我将批量大小设置为 1.

trainLoader = DataLoader(trainDataset, batch_size=1, shuffle=False)
for x_batch, y_batch in trainLoader:
    #process and train

您可以实施自定义 IterableDataset 并定义 __next____iter__ 以跳过您的 getData 函数引发错误的任何实例:

这是一个可能的虚拟数据实现:

class DS(IterableDataset):
    def __init__(self):
        self.data = torch.randint(0,3,(20,))
        self._i = -1

    def getData(self, index):
        x = self.data[index]
        if x == 0:
            raise ValueError
        return x

    def __iter__(self):
        return self

    def __next__(self):
        self._i += 1
        if self._i == len(self.data):  # out of instances
            self._i = -1               # reset the iterable
            raise StopIteration        # stop the iteration
        try:
            return self.getData(self._i)
        except ValueError:
            return next(self)

你会像这样使用它:

>>> trainLoader = DataLoader(DS(), batch_size=1, shuffle=False)
>>> for x in trainLoader:
...    print(x)
tensor([1])
tensor([2])
tensor([2])
...
tensor([1])
tensor([1])

此处所有 0 个实例已在可迭代数据集中被跳过。

您可以调整这个简单示例以满足您的需要。