如何使用 PyTorch DataLoader 创建批次,以便给定批次中的每个示例都具有相同的属性值?
How to create batches using PyTorch DataLoader such that each example in a given batch has the same value for an attribute?
假设我有一个列表,datalist
,其中包含几个示例(对于我的用例,它们的类型为 torch_geometric.data.Data
)。每个示例都有一个属性 num_nodes
出于演示目的,可以使用以下代码片段
创建这样的 datalist
import torch
from torch_geometric.data import Data # each example is of this type
import networkx as nx # for creating random data
import numpy as np
# the python list containing the examples
datalist = []
for num_node in [9, 11]:
for _ in range(1024):
edge_index = torch.from_numpy(
np.array(nx.fast_gnp_random_graph(num_node, 0.5).edges())
).t().contiguous()
datalist.append(
Data(
x=torch.rand(num_node, 5),
edge_index=edge_index,
edge_attr=torch.rand(edge_index.size(1))
)
)
从上面的 datalist
对象,我可以通过使用 DataLoader constructor 为:
from torch_geometric.loader import DataLoader
dataloader = DataLoader(
datalist, batch_size=128, shuffle=True
)
我的问题是,如何使用 DataLoader
class 来确保给定批次中的每个示例都具有相同的 num_nodes
属性值?
PS:
我试图解决它并通过使用 中的 combine_iterators
函数片段组合多个 DataLoader
对象提出了一个 hacky 解决方案,如下所示:
def get_combined_iterator(*iterables):
nexts = [iter(iterable).__next__ for iterable in iterables]
while nexts:
next = random.choice(nexts)
try:
yield next()
except StopIteration:
nexts.remove(next)
datalists = defaultdict(list)
for data in datalist:
datalists[data.num_nodes].append(data)
dataloaders = (
DataLoader(data, batch_size=128, shuffle=True) for data in datalists.values()
)
batches = get_combined_iterator(*dataloaders)
但是,我认为一定有一些 elegant/better 方法可以做到这一点,因此出现了这个问题。
如果您的基础数据集是 map-style,您可以定义一个 torch.utils.data.Sampler
,其中 returns 您想要一起批处理的示例的索引。它的一个实例将作为 batch_sampler
kwarg 传递给您的 DataLoader
并且您可以删除 batch_size
kwarg,因为采样器将根据您的实现方式为您形成批次。
按照 erip's suggestion, I subclassed torch.utils.data.sampler.Sampler
创建一个新的采样器:BucketSampler
它使用 torch.utils.data.sampler.SubsetRandomSampler
和 torch.utils.data.sampler.BatchSampler
来实现对给定属性具有相同值的示例进行批处理。
import torch
from torch.utils.data.sampler import Sampler, BatchSampler, SubsetRandomSampler
class BucketSampler(Sampler):
def __init__(self, dataset, batch_size, start_pos_data, generator=None) -> None:
self.dataset = dataset
self.batch_size = batch_size
self.generator = generator
start_pos_data = start_pos_data
start_end_indices = []
for i in range(len(start_pos_data) - 1):
start_end_indices.append((start_pos_data[i], start_pos_data[i+1]))
start_end_indices.append((start_pos_data[-1], len(self.dataset)))
ranges = [range(start, end) for start, end in start_end_indices]
subset_samplers = [SubsetRandomSampler(range_, generator=generator) for range_ in ranges]
self.samplers = [
BatchSampler(subset_sampler, batch_size, drop_last=False) for subset_sampler in subset_samplers
]
self._len = 0
for sampler in self.samplers:
self._len += len(sampler)
def __iter__(self):
iterators = [iter(sampler) for sampler in self.samplers]
while iterators:
randint = torch.randint(0, len(iterators),size=(1,), generator=self.generator)[0]
try:
yield next(iterators[randint])
except StopIteration:
iterators.pop(randint)
def __len__(self):
return self._len
除了通常的参数之外,这个 class 也将 start_pos_data
作为参数,它是一个列表,其中包含 datalist
中的第一个索引(来自问题中给出的示例的数据集) 属性值更改的位置。因此,对于上面的示例,我们可以借助以下代码片段创建这样一个列表:
# sort datalist to ensure that data items with the same number of nodes are grouped together
sorted_datalist = sorted(datalist, key = lambda data: data.num_nodes)
# initialize the start_pos_data by 0
start_pos_data = [0]
for i in range(1,len(sorted_datalist)):
if sorted_datalist[i].num_nodes != sorted_datalist[i-1].num_nodes:
# append when the number of nodes changes
start_pos_data.append(i)
现在,start_pos_data
可以传递给BucketSampler
的构造函数初始化采样器
bucketSampler = BucketSampler(sorted_datalist, batch_size = 128, start_pos_data = start_pos_data)
在此之后,bucketSampler
可以作为 kwarg
到 DataLoader
构造函数传递给:
from torch_geometric.loader import DataLoader
dataloader = DataLoader(sorted_datalist, batch_sampler = bucketSampler)
此 dataloader
(在迭代时)将以所需方式生成批次。
假设我有一个列表,datalist
,其中包含几个示例(对于我的用例,它们的类型为 torch_geometric.data.Data
)。每个示例都有一个属性 num_nodes
出于演示目的,可以使用以下代码片段
创建这样的datalist
import torch
from torch_geometric.data import Data # each example is of this type
import networkx as nx # for creating random data
import numpy as np
# the python list containing the examples
datalist = []
for num_node in [9, 11]:
for _ in range(1024):
edge_index = torch.from_numpy(
np.array(nx.fast_gnp_random_graph(num_node, 0.5).edges())
).t().contiguous()
datalist.append(
Data(
x=torch.rand(num_node, 5),
edge_index=edge_index,
edge_attr=torch.rand(edge_index.size(1))
)
)
从上面的 datalist
对象,我可以通过使用 DataLoader constructor 为:
from torch_geometric.loader import DataLoader
dataloader = DataLoader(
datalist, batch_size=128, shuffle=True
)
我的问题是,如何使用 DataLoader
class 来确保给定批次中的每个示例都具有相同的 num_nodes
属性值?
PS:
我试图解决它并通过使用 combine_iterators
函数片段组合多个 DataLoader
对象提出了一个 hacky 解决方案,如下所示:
def get_combined_iterator(*iterables):
nexts = [iter(iterable).__next__ for iterable in iterables]
while nexts:
next = random.choice(nexts)
try:
yield next()
except StopIteration:
nexts.remove(next)
datalists = defaultdict(list)
for data in datalist:
datalists[data.num_nodes].append(data)
dataloaders = (
DataLoader(data, batch_size=128, shuffle=True) for data in datalists.values()
)
batches = get_combined_iterator(*dataloaders)
但是,我认为一定有一些 elegant/better 方法可以做到这一点,因此出现了这个问题。
如果您的基础数据集是 map-style,您可以定义一个 torch.utils.data.Sampler
,其中 returns 您想要一起批处理的示例的索引。它的一个实例将作为 batch_sampler
kwarg 传递给您的 DataLoader
并且您可以删除 batch_size
kwarg,因为采样器将根据您的实现方式为您形成批次。
按照 erip's suggestion, I subclassed torch.utils.data.sampler.Sampler
创建一个新的采样器:BucketSampler
它使用 torch.utils.data.sampler.SubsetRandomSampler
和 torch.utils.data.sampler.BatchSampler
来实现对给定属性具有相同值的示例进行批处理。
import torch
from torch.utils.data.sampler import Sampler, BatchSampler, SubsetRandomSampler
class BucketSampler(Sampler):
def __init__(self, dataset, batch_size, start_pos_data, generator=None) -> None:
self.dataset = dataset
self.batch_size = batch_size
self.generator = generator
start_pos_data = start_pos_data
start_end_indices = []
for i in range(len(start_pos_data) - 1):
start_end_indices.append((start_pos_data[i], start_pos_data[i+1]))
start_end_indices.append((start_pos_data[-1], len(self.dataset)))
ranges = [range(start, end) for start, end in start_end_indices]
subset_samplers = [SubsetRandomSampler(range_, generator=generator) for range_ in ranges]
self.samplers = [
BatchSampler(subset_sampler, batch_size, drop_last=False) for subset_sampler in subset_samplers
]
self._len = 0
for sampler in self.samplers:
self._len += len(sampler)
def __iter__(self):
iterators = [iter(sampler) for sampler in self.samplers]
while iterators:
randint = torch.randint(0, len(iterators),size=(1,), generator=self.generator)[0]
try:
yield next(iterators[randint])
except StopIteration:
iterators.pop(randint)
def __len__(self):
return self._len
除了通常的参数之外,这个 class 也将 start_pos_data
作为参数,它是一个列表,其中包含 datalist
中的第一个索引(来自问题中给出的示例的数据集) 属性值更改的位置。因此,对于上面的示例,我们可以借助以下代码片段创建这样一个列表:
# sort datalist to ensure that data items with the same number of nodes are grouped together
sorted_datalist = sorted(datalist, key = lambda data: data.num_nodes)
# initialize the start_pos_data by 0
start_pos_data = [0]
for i in range(1,len(sorted_datalist)):
if sorted_datalist[i].num_nodes != sorted_datalist[i-1].num_nodes:
# append when the number of nodes changes
start_pos_data.append(i)
现在,start_pos_data
可以传递给BucketSampler
的构造函数初始化采样器
bucketSampler = BucketSampler(sorted_datalist, batch_size = 128, start_pos_data = start_pos_data)
在此之后,bucketSampler
可以作为 kwarg
到 DataLoader
构造函数传递给:
from torch_geometric.loader import DataLoader
dataloader = DataLoader(sorted_datalist, batch_sampler = bucketSampler)
此 dataloader
(在迭代时)将以所需方式生成批次。