给定 pytorch 中的 idx 列表,如何从数据集中获取一批样本?
How can I get a batch of samples from a dataset given a list of idxs in pytorch?
我有一个 torch.utils.data.Dataset
对象,我想要一个 DataLoader
或类似的对象,它接受一个 idx 列表和 returns 一批具有相应 idx 的样本.
例子,我有
list_idxs = [10, 109, 7, 12]
我想这样做:
batch = loader.getbatch(list_idxs)
其中批次包含:
[sample10, sample109, sample7, sample12]
是否有一种简单而优雅的优化方式来做到这一点?
如果我对你的问题的理解正确,你可以有一个 DataLoader
return 使用自定义 batch_sampler
的一系列手工选择的批次(你甚至不需要通过在这种情况下是 sampler
)。
任意给定Dataset
:
>>> from torch.utils.data import DataLoader, Dataset
>>> from torch.utils.data.sampler import Sampler
>>> class MyDataset(Dataset):
... def __getitem__(self, idx):
... return idx
然后您可以定义如下内容:
>>> class MyBatchSampler(Sampler):
... def __init__(self, batches):
... self.batches = batches
...
... def __iter__(self):
... for batch in self.batches:
... yield batch
...
... def __len__(self):
... return len(self.batches)
它只需要一个列表列表,其中包含要包含在每个批次中的数据集索引。
然后:
>>> dataset = MyDataset()
>>> batch_sampler = MyBatchSampler([[1, 2, 3], [5, 6, 7], [4, 2, 1]])
>>> dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler)
>>> for batch in dataloader:
... print(batch)
...
tensor([1, 2, 3])
tensor([5, 6, 7])
tensor([4, 2, 1])
应该很容易扩展到您的实际数据集等
我有一个 torch.utils.data.Dataset
对象,我想要一个 DataLoader
或类似的对象,它接受一个 idx 列表和 returns 一批具有相应 idx 的样本.
例子,我有
list_idxs = [10, 109, 7, 12]
我想这样做:
batch = loader.getbatch(list_idxs)
其中批次包含:
[sample10, sample109, sample7, sample12]
是否有一种简单而优雅的优化方式来做到这一点?
如果我对你的问题的理解正确,你可以有一个 DataLoader
return 使用自定义 batch_sampler
的一系列手工选择的批次(你甚至不需要通过在这种情况下是 sampler
)。
任意给定Dataset
:
>>> from torch.utils.data import DataLoader, Dataset
>>> from torch.utils.data.sampler import Sampler
>>> class MyDataset(Dataset):
... def __getitem__(self, idx):
... return idx
然后您可以定义如下内容:
>>> class MyBatchSampler(Sampler):
... def __init__(self, batches):
... self.batches = batches
...
... def __iter__(self):
... for batch in self.batches:
... yield batch
...
... def __len__(self):
... return len(self.batches)
它只需要一个列表列表,其中包含要包含在每个批次中的数据集索引。
然后:
>>> dataset = MyDataset()
>>> batch_sampler = MyBatchSampler([[1, 2, 3], [5, 6, 7], [4, 2, 1]])
>>> dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler)
>>> for batch in dataloader:
... print(batch)
...
tensor([1, 2, 3])
tensor([5, 6, 7])
tensor([4, 2, 1])
应该很容易扩展到您的实际数据集等