PyTorch - 丢弃数据加载器批处理
PyTorch - discard dataloader batch
我有一个自定义 Dataset
从大文件加载数据。有时,加载的数据是空的,我不想用它们来训练。
在Dataset
我有:
def __getitem__(self, i):
(x, y) = self.getData(i) #getData loads data and handles problems
return (x, y)
如果数据不正确 return (None, None)
(x
和 y
都是 None
)。但是,它后来在 DataLoader
中失败了,我无法完全跳过这批。我将批量大小设置为 1
.
trainLoader = DataLoader(trainDataset, batch_size=1, shuffle=False)
for x_batch, y_batch in trainLoader:
#process and train
您可以实施自定义 IterableDataset
并定义 __next__
和 __iter__
以跳过您的 getData
函数引发错误的任何实例:
这是一个可能的虚拟数据实现:
class DS(IterableDataset):
def __init__(self):
self.data = torch.randint(0,3,(20,))
self._i = -1
def getData(self, index):
x = self.data[index]
if x == 0:
raise ValueError
return x
def __iter__(self):
return self
def __next__(self):
self._i += 1
if self._i == len(self.data): # out of instances
self._i = -1 # reset the iterable
raise StopIteration # stop the iteration
try:
return self.getData(self._i)
except ValueError:
return next(self)
你会像这样使用它:
>>> trainLoader = DataLoader(DS(), batch_size=1, shuffle=False)
>>> for x in trainLoader:
... print(x)
tensor([1])
tensor([2])
tensor([2])
...
tensor([1])
tensor([1])
此处所有 0
个实例已在可迭代数据集中被跳过。
您可以调整这个简单示例以满足您的需要。
我有一个自定义 Dataset
从大文件加载数据。有时,加载的数据是空的,我不想用它们来训练。
在Dataset
我有:
def __getitem__(self, i):
(x, y) = self.getData(i) #getData loads data and handles problems
return (x, y)
如果数据不正确 return (None, None)
(x
和 y
都是 None
)。但是,它后来在 DataLoader
中失败了,我无法完全跳过这批。我将批量大小设置为 1
.
trainLoader = DataLoader(trainDataset, batch_size=1, shuffle=False)
for x_batch, y_batch in trainLoader:
#process and train
您可以实施自定义 IterableDataset
并定义 __next__
和 __iter__
以跳过您的 getData
函数引发错误的任何实例:
这是一个可能的虚拟数据实现:
class DS(IterableDataset):
def __init__(self):
self.data = torch.randint(0,3,(20,))
self._i = -1
def getData(self, index):
x = self.data[index]
if x == 0:
raise ValueError
return x
def __iter__(self):
return self
def __next__(self):
self._i += 1
if self._i == len(self.data): # out of instances
self._i = -1 # reset the iterable
raise StopIteration # stop the iteration
try:
return self.getData(self._i)
except ValueError:
return next(self)
你会像这样使用它:
>>> trainLoader = DataLoader(DS(), batch_size=1, shuffle=False)
>>> for x in trainLoader:
... print(x)
tensor([1])
tensor([2])
tensor([2])
...
tensor([1])
tensor([1])
此处所有 0
个实例已在可迭代数据集中被跳过。
您可以调整这个简单示例以满足您的需要。