加快自定义消息丢失的 pytorch 操作
Speeding up pytorch operations for custom message dropout
我正在尝试在 PyTorch Geometric 中的自定义 MessagePassing 卷积中实现消息丢失。消息丢失包括随机忽略图中 p% 的边。我的想法是从 forward()
.
中的输入 edge_index
中随机删除其中的 p%
edge_index
是形状为 (2, num_edges)
的张量,其中第一个维度是“从”节点 ID,第二个维度是“到”节点 ID。所以我认为我可以do 是 select range(N)
的随机样本,然后用它来掩盖其余索引:
def forward(self, x, edge_index, edge_attr=None):
if self.message_dropout is not None:
# TODO: this is way too slow (4-5 times slower than without it)
# message dropout -> randomly ignore p % of edges in the graph i.e. keep only (1-p) % of them
random_keep_inx = random.sample(range(edge_index.shape[1]), int((1.0 - self.message_dropout) * edge_index.shape[1]))
edge_index_to_use = edge_index[:, random_keep_inx]
edge_attr_to_use = edge_attr[random_keep_inx] if edge_attr is not None else None
else:
edge_index_to_use = edge_index
edge_attr_to_use = edge_attr
...
但是,它太慢了,它使一个纪元进行了 5' 而不是 1' 没有(慢 5 倍)。在 PyTorch 中有更快的方法吗?
编辑:瓶颈似乎是 random.sample()
调用,而不是屏蔽。所以我想我应该问的是更快的替代品。
我设法使用 PyTorch 的函数式 Dropout 创建了一个布尔掩码,速度要快得多。现在一个纪元又需要 ~1'。比我在其他地方找到的具有排列的其他解决方案更好。
def forward(self, x, edge_index, edge_attr=None):
if self.message_dropout is not None:
# message dropout -> randomly ignore p % of edges in the graph
mask = F.dropout(torch.ones(edge_index.shape[1]), self.message_dropout, self.training) > 0
edge_index_to_use = edge_index[:, mask]
edge_attr_to_use = edge_attr[mask] if edge_attr is not None else None
else:
edge_index_to_use = edge_index
edge_attr_to_use = edge_attr
...
我正在尝试在 PyTorch Geometric 中的自定义 MessagePassing 卷积中实现消息丢失。消息丢失包括随机忽略图中 p% 的边。我的想法是从 forward()
.
edge_index
中随机删除其中的 p%
edge_index
是形状为 (2, num_edges)
的张量,其中第一个维度是“从”节点 ID,第二个维度是“到”节点 ID。所以我认为我可以do 是 select range(N)
的随机样本,然后用它来掩盖其余索引:
def forward(self, x, edge_index, edge_attr=None):
if self.message_dropout is not None:
# TODO: this is way too slow (4-5 times slower than without it)
# message dropout -> randomly ignore p % of edges in the graph i.e. keep only (1-p) % of them
random_keep_inx = random.sample(range(edge_index.shape[1]), int((1.0 - self.message_dropout) * edge_index.shape[1]))
edge_index_to_use = edge_index[:, random_keep_inx]
edge_attr_to_use = edge_attr[random_keep_inx] if edge_attr is not None else None
else:
edge_index_to_use = edge_index
edge_attr_to_use = edge_attr
...
但是,它太慢了,它使一个纪元进行了 5' 而不是 1' 没有(慢 5 倍)。在 PyTorch 中有更快的方法吗?
编辑:瓶颈似乎是 random.sample()
调用,而不是屏蔽。所以我想我应该问的是更快的替代品。
我设法使用 PyTorch 的函数式 Dropout 创建了一个布尔掩码,速度要快得多。现在一个纪元又需要 ~1'。比我在其他地方找到的具有排列的其他解决方案更好。
def forward(self, x, edge_index, edge_attr=None):
if self.message_dropout is not None:
# message dropout -> randomly ignore p % of edges in the graph
mask = F.dropout(torch.ones(edge_index.shape[1]), self.message_dropout, self.training) > 0
edge_index_to_use = edge_index[:, mask]
edge_attr_to_use = edge_attr[mask] if edge_attr is not None else None
else:
edge_index_to_use = edge_index
edge_attr_to_use = edge_attr
...