如何在 Dataloader 中使用 Batchsampler
How to use a Batchsampler within a Dataloader
我需要在 pytorch DataLoader
中使用 BatchSampler
而不是多次调用数据集的 __getitem__
(远程数据集,每个查询都很昂贵)。
我无法理解如何将 batchsampler 与 any 给定数据集一起使用。
例如
class MyDataset(Dataset):
def __init__(self, remote_ddf, ):
self.ddf = remote_ddf
def __len__(self):
return len(self.ddf)
def __getitem__(self, idx):
return self.ddf[idx] --------> This is as expensive as a batch call
def get_batch(self, batch_idx):
return self.ddf[batch_idx]
my_loader = DataLoader(MyDataset(remote_ddf),
batch_sampler=BatchSampler(Sampler(), batch_size=3))
我不明白的是我如何使用我的 get_batch
函数而不是 __getitem__ 函数,既没有在网上也没有在火炬文档中找到任何示例。
编辑:
按照 Szymon Maszke 的回答,这是我尝试过的方法,但是 \_\_get_item__
每次调用都会得到一个索引,而不是大小为 batch_size
的列表
class Dataset(Dataset):
def __init__(self):
...
def __len__(self):
...
def __getitem__(self, batch_idx): ------> here I get only one index
return self.wiki_df.loc[batch_idx]
loader = DataLoader(
dataset=dataset,
batch_sampler=BatchSampler(
SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False),
num_workers=self.hparams.num_data_workers,
)
您不能使用 get_batch
而不是 __getitem__
,而且我认为这样做没有意义。
torch.utils.data.BatchSampler
takes indices from your Sampler()
instance (in this case 3
of them) and returns it as list
so those can be used in your MyDataset
__getitem__
method (check source code,大多数采样器和数据相关的实用程序都很容易理解,以备不时之需。
我假设您的 self.ddf
支持列表切片(例如 self.ddf[[25, 44, 115]]
returns 值正确并且只使用一次昂贵的调用)。在这种情况下,只需将 get_batch
切换为 __getitem__
就可以了。
class MyDataset(Dataset):
def __init__(self, remote_ddf, ):
self.ddf = remote_ddf
def __len__(self):
return len(self.ddf)
def __getitem__(self, batch_idx):
return self.ddf[batch_idx] -> batch_idx is a list
编辑: 您必须将 batch_sampler
指定为 sampler
,否则批次将被分成单个索引。这应该没问题:
loader = DataLoader(
dataset=dataset,
# This line below!
sampler=BatchSampler(
SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False
),
num_workers=self.hparams.num_data_workers,
)
我需要在 pytorch DataLoader
中使用 BatchSampler
而不是多次调用数据集的 __getitem__
(远程数据集,每个查询都很昂贵)。
我无法理解如何将 batchsampler 与 any 给定数据集一起使用。
例如
class MyDataset(Dataset):
def __init__(self, remote_ddf, ):
self.ddf = remote_ddf
def __len__(self):
return len(self.ddf)
def __getitem__(self, idx):
return self.ddf[idx] --------> This is as expensive as a batch call
def get_batch(self, batch_idx):
return self.ddf[batch_idx]
my_loader = DataLoader(MyDataset(remote_ddf),
batch_sampler=BatchSampler(Sampler(), batch_size=3))
我不明白的是我如何使用我的 get_batch
函数而不是 __getitem__ 函数,既没有在网上也没有在火炬文档中找到任何示例。
编辑:
按照 Szymon Maszke 的回答,这是我尝试过的方法,但是 \_\_get_item__
每次调用都会得到一个索引,而不是大小为 batch_size
class Dataset(Dataset):
def __init__(self):
...
def __len__(self):
...
def __getitem__(self, batch_idx): ------> here I get only one index
return self.wiki_df.loc[batch_idx]
loader = DataLoader(
dataset=dataset,
batch_sampler=BatchSampler(
SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False),
num_workers=self.hparams.num_data_workers,
)
您不能使用 get_batch
而不是 __getitem__
,而且我认为这样做没有意义。
torch.utils.data.BatchSampler
takes indices from your Sampler()
instance (in this case 3
of them) and returns it as list
so those can be used in your MyDataset
__getitem__
method (check source code,大多数采样器和数据相关的实用程序都很容易理解,以备不时之需。
我假设您的 self.ddf
支持列表切片(例如 self.ddf[[25, 44, 115]]
returns 值正确并且只使用一次昂贵的调用)。在这种情况下,只需将 get_batch
切换为 __getitem__
就可以了。
class MyDataset(Dataset):
def __init__(self, remote_ddf, ):
self.ddf = remote_ddf
def __len__(self):
return len(self.ddf)
def __getitem__(self, batch_idx):
return self.ddf[batch_idx] -> batch_idx is a list
编辑: 您必须将 batch_sampler
指定为 sampler
,否则批次将被分成单个索引。这应该没问题:
loader = DataLoader(
dataset=dataset,
# This line below!
sampler=BatchSampler(
SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False
),
num_workers=self.hparams.num_data_workers,
)