PyTorch:自定义批量采样器在第一个纪元后耗尽
PyTorch: Custom batch sampler exhausts after first epoch
我使用带有自定义 batch_sampler
的 DataLoader
来确保每个批次 class 平衡。如何防止迭代器在第一个纪元耗尽自身?
import torch
class CustomDataset(torch.utils.data.Dataset):
def __init__(self):
self.x = torch.rand(10, 10)
self.y = torch.Tensor([0] * 5 + [1] * 5)
def __len__(self):
len(self.y)
def __getitem__(self, idx):
return self.x[idx], self.y[idx]
def custom_batch_sampler():
batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
return iter(batch_idx)
def train(loader):
for epoch in range(10):
for batch, (x, y) in enumerate(loader):
print('epoch:', epoch, 'batch:', batch) # stops after first epoch
if __name__=='__main__':
my_dataset = CustomDataset()
my_loader = torch.utils.data.DataLoader(
dataset=my_dataset,
batch_sampler=custom_batch_sampler()
)
train(my_loader)
训练在第一个纪元后停止,next(iter(loader))
给出 StopIteration
错误。
epoch: 0 batch: 0
epoch: 0 batch: 1
epoch: 0 batch: 2
epoch: 0 batch: 3
epoch: 0 batch: 4
自定义批量采样器需要是 Sampler
或一些可迭代的。在每个时期,都会从这个可迭代对象中生成一个新的迭代器。这意味着您实际上不需要手动制作迭代器(它将 运行 出来并在第一个纪元后提高 StopIteration
),但您可以只提供您的列表,因此如果您删除它应该可以工作iter()
:
def custom_batch_sampler():
batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
return batch_idx
我使用带有自定义 batch_sampler
的 DataLoader
来确保每个批次 class 平衡。如何防止迭代器在第一个纪元耗尽自身?
import torch
class CustomDataset(torch.utils.data.Dataset):
def __init__(self):
self.x = torch.rand(10, 10)
self.y = torch.Tensor([0] * 5 + [1] * 5)
def __len__(self):
len(self.y)
def __getitem__(self, idx):
return self.x[idx], self.y[idx]
def custom_batch_sampler():
batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
return iter(batch_idx)
def train(loader):
for epoch in range(10):
for batch, (x, y) in enumerate(loader):
print('epoch:', epoch, 'batch:', batch) # stops after first epoch
if __name__=='__main__':
my_dataset = CustomDataset()
my_loader = torch.utils.data.DataLoader(
dataset=my_dataset,
batch_sampler=custom_batch_sampler()
)
train(my_loader)
训练在第一个纪元后停止,next(iter(loader))
给出 StopIteration
错误。
epoch: 0 batch: 0
epoch: 0 batch: 1
epoch: 0 batch: 2
epoch: 0 batch: 3
epoch: 0 batch: 4
自定义批量采样器需要是 Sampler
或一些可迭代的。在每个时期,都会从这个可迭代对象中生成一个新的迭代器。这意味着您实际上不需要手动制作迭代器(它将 运行 出来并在第一个纪元后提高 StopIteration
),但您可以只提供您的列表,因此如果您删除它应该可以工作iter()
:
def custom_batch_sampler():
batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
return batch_idx