给定 pytorch 中的 idx 列表,如何从数据集中获取一批样本?

How can I get a batch of samples from a dataset given a list of idxs in pytorch?

我有一个 torch.utils.data.Dataset 对象,我想要一个 DataLoader 或类似的对象,它接受一个 idx 列表和 returns 一批具有相应 idx 的样本.

例子,我有

list_idxs = [10, 109, 7, 12]

我想这样做:

batch = loader.getbatch(list_idxs)

其中批次包含:

[sample10, sample109, sample7, sample12]

是否有一种简单而优雅的优化方式来做到这一点?

如果我对你的问题的理解正确,你可以有一个 DataLoader return 使用自定义 batch_sampler 的一系列手工选择的批次(你甚至不需要通过在这种情况下是 sampler)。

任意给定Dataset:

>>> from torch.utils.data import DataLoader, Dataset
>>> from torch.utils.data.sampler import Sampler
>>> class MyDataset(Dataset):
...     def __getitem__(self, idx):
...         return idx

然后您可以定义如下内容:

>>> class MyBatchSampler(Sampler):
...     def __init__(self, batches):
...         self.batches = batches
...
...     def __iter__(self):
...         for batch in self.batches:
...             yield batch
...
...     def __len__(self):
...         return len(self.batches)

它只需要一个列表列表,其中包含要包含在每个批次中的数据集索引。

然后:

>>> dataset = MyDataset()
>>> batch_sampler = MyBatchSampler([[1, 2, 3], [5, 6, 7], [4, 2, 1]])
>>> dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler)
>>> for batch in dataloader:
...     print(batch)
... 
tensor([1, 2, 3])
tensor([5, 6, 7])
tensor([4, 2, 1])

应该很容易扩展到您的实际数据集等