__getitem__ 的 idx 如何在 PyTorch 的 DataLoader 中工作?
How does the __getitem__'s idx work within PyTorch's DataLoader?
我目前正在尝试使用 PyTorch 的 DataLoader 来处理数据以输入我的深度学习模型,但遇到了一些困难。
我需要的数据的形状是(minibatch_size=32, rows=100, columns=41)
。我编写的自定义 Dataset
class 中的 __getitem__
代码看起来像这样:
def __getitem__(self, idx):
x = np.array(self.train.iloc[idx:100, :])
return x
我这样写的原因是因为我希望 DataLoader 一次处理形状 (100, 41)
的输入实例,而我们有 32 个这样的单个实例。
但是,我注意到与我最初的看法相反,DataLoader 传递给函数的 idx
参数不是顺序的(这很重要,因为我的数据是时间序列数据)。例如,打印值给了我这样的东西:
idx = 206000
idx = 113814
idx = 80597
idx = 3836
idx = 156187
idx = 54990
idx = 8694
idx = 190555
idx = 84418
idx = 161773
idx = 177725
idx = 178351
idx = 89217
idx = 11048
idx = 135994
idx = 15067
这是正常行为吗?我发布这个问题是因为返回的数据批次不是我最初想要的。
我在使用 DataLoader 之前用来预处理数据的原始逻辑是:
- 从
txt
或 csv
文件中读取数据。
- 计算数据中有多少批次并相应地对数据进行切片。例如,由于一个输入实例的形状为
(100, 41)
,其中 32 个实例形成一个小批量,我们通常最终会得到大约 100 个左右的批量,并相应地重新整形数据。
- 一个输入的形状为
(32, 100, 41)
。
我不确定我还应该如何处理 DataLoader 挂钩方法。非常感谢任何提示或建议。提前致谢。
定义 idx 的是 sampler
或 batch_sampler
,如您所见here (open-source projects are your friend). In this code (and comment/docstring) you can see the difference between sampler
and batch_sampler
. If you look here 您将看到如何选择索引:
def __next__(self):
index = self._next_index()
# and _next_index is implemented on the base class (_BaseDataLoaderIter)
def _next_index(self):
return next(self._sampler_iter)
# self._sampler_iter is defined in the __init__ like this:
self._sampler_iter = iter(self._index_sampler)
# and self._index_sampler is a property implemented like this (modified to one-liner for simplicity):
self._index_sampler = self.batch_sampler if self._auto_collation else self.sampler
注意这是_SingleProcessDataLoaderIter
的实现;你可以找到 _MultiProcessingDataLoaderIter
here (ofc, which one is used depends on the num_workers
value, as you can see here). Going back to the samplers, assuming your Dataset is not _DatasetKind.Iterable
and that you are not providing a custom sampler, it means you are either using (dataloader.py#L212-L215):
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
我们来看看how the default BatchSampler builds a batch:
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
非常简单:它从采样器获取索引,直到达到所需的 batch_size。
现在可以通过查看每个默认采样器的工作原理来回答问题 "How does the __getitem__'s idx work within PyTorch's DataLoader?"。
- SequentialSampler(这是完整的实现——非常简单,不是吗?):
class SequentialSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
- RandomSampler(让我们只看
__iter__
实现):
def __iter__(self):
n = len(self.data_source)
if self.replacement:
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
return iter(torch.randperm(n).tolist())
因此,由于您没有提供任何代码,我们只能假设:
- 您在 DataLoader 或
中使用 shuffle=True
- 您正在使用自定义采样器或
- 您的数据集是
_DatasetKind.Iterable
我目前正在尝试使用 PyTorch 的 DataLoader 来处理数据以输入我的深度学习模型,但遇到了一些困难。
我需要的数据的形状是(minibatch_size=32, rows=100, columns=41)
。我编写的自定义 Dataset
class 中的 __getitem__
代码看起来像这样:
def __getitem__(self, idx):
x = np.array(self.train.iloc[idx:100, :])
return x
我这样写的原因是因为我希望 DataLoader 一次处理形状 (100, 41)
的输入实例,而我们有 32 个这样的单个实例。
但是,我注意到与我最初的看法相反,DataLoader 传递给函数的 idx
参数不是顺序的(这很重要,因为我的数据是时间序列数据)。例如,打印值给了我这样的东西:
idx = 206000
idx = 113814
idx = 80597
idx = 3836
idx = 156187
idx = 54990
idx = 8694
idx = 190555
idx = 84418
idx = 161773
idx = 177725
idx = 178351
idx = 89217
idx = 11048
idx = 135994
idx = 15067
这是正常行为吗?我发布这个问题是因为返回的数据批次不是我最初想要的。
我在使用 DataLoader 之前用来预处理数据的原始逻辑是:
- 从
txt
或csv
文件中读取数据。 - 计算数据中有多少批次并相应地对数据进行切片。例如,由于一个输入实例的形状为
(100, 41)
,其中 32 个实例形成一个小批量,我们通常最终会得到大约 100 个左右的批量,并相应地重新整形数据。 - 一个输入的形状为
(32, 100, 41)
。
我不确定我还应该如何处理 DataLoader 挂钩方法。非常感谢任何提示或建议。提前致谢。
定义 idx 的是 sampler
或 batch_sampler
,如您所见here (open-source projects are your friend). In this code (and comment/docstring) you can see the difference between sampler
and batch_sampler
. If you look here 您将看到如何选择索引:
def __next__(self):
index = self._next_index()
# and _next_index is implemented on the base class (_BaseDataLoaderIter)
def _next_index(self):
return next(self._sampler_iter)
# self._sampler_iter is defined in the __init__ like this:
self._sampler_iter = iter(self._index_sampler)
# and self._index_sampler is a property implemented like this (modified to one-liner for simplicity):
self._index_sampler = self.batch_sampler if self._auto_collation else self.sampler
注意这是_SingleProcessDataLoaderIter
的实现;你可以找到 _MultiProcessingDataLoaderIter
here (ofc, which one is used depends on the num_workers
value, as you can see here). Going back to the samplers, assuming your Dataset is not _DatasetKind.Iterable
and that you are not providing a custom sampler, it means you are either using (dataloader.py#L212-L215):
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
我们来看看how the default BatchSampler builds a batch:
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
非常简单:它从采样器获取索引,直到达到所需的 batch_size。
现在可以通过查看每个默认采样器的工作原理来回答问题 "How does the __getitem__'s idx work within PyTorch's DataLoader?"。
- SequentialSampler(这是完整的实现——非常简单,不是吗?):
class SequentialSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
- RandomSampler(让我们只看
__iter__
实现):
def __iter__(self):
n = len(self.data_source)
if self.replacement:
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
return iter(torch.randperm(n).tolist())
因此,由于您没有提供任何代码,我们只能假设:
- 您在 DataLoader 或 中使用
- 您正在使用自定义采样器或
- 您的数据集是
_DatasetKind.Iterable
shuffle=True