for循环优化创建邻接矩阵

For loop optimization to create an adjacency matrix

我目前正在处理带有标记边的图形。 原始邻接矩阵是形状为 [n_nodes、n_nodes、n_edges] 的矩阵,其中如果节点 i 和 j 通过边 k 连接,则每个单元格 [i,j, k] 为 1 .

我需要创建原始图的反向图,其中节点变成边,边变成节点,所以我需要一个形状为 [n_edges、n_edges、[=24= 的新矩阵]],其中如果边 i 和 j 具有 k 作为公共顶点,则每个单元格 [i,j,k] 为 1。

下面的代码正确地完成了任务,但是使用 5 个嵌套的 for 循环太慢了,处理我必须处理的图表数量似乎需要大约 700 个小时。

有没有更好的实现方式?

n_nodes = extended_adj.shape[0]
n_edges = extended_adj.shape[2]
reversed_graph = torch.zeros(n_edges, n_edges, n_nodes, 1)
for i in range(n_nodes):
    for j in range(n_nodes):
        for k in range(n_edges):
            #If adj_mat[i][j][k] == 1 nodes i and j are connected with edge k
            #For this reason the edge k must be connected via node j to every outcoming edge of j
            if extended_adj[i][j][k] == 1:
                #Given node j, we need to loop through every other possible node (l)
                for l in range(n_nodes):
                    #For every other node, we need to check if they are connected by an edge (m)
                    for m in range(n_edges):
                        if extended_adj[j][l][m] == 1:
                            reversed_graph[k][m][j] = 1

先谢谢了。

与上面的评论相呼应,这种图形表示几乎肯定是麻烦且低效的。但尽管如此,让我们定义一个没有循环的矢量化解决方案,并尽可能使用张量视图,这对于计算更大的图应该是相当有效的。

为清楚起见,我们使用 [i,j,k] 索引 G(原始图形),使用 [i',j',k'] 索引 G'(新图形)。让我们将 n_edges 缩短为 e,将 n_nodes 缩短为 n

考虑二维矩阵 slice = torch.max(G,dim = 1)。在此切片的每个坐标 [a,b] 处,1 表示节点 a 通过边 b 连接到某个其他节点(我们不关心哪个节点)。

slice = torch.max(G,dim = 1)                                     # dimension [n,e]

我们正在寻找解决方案,但我们需要一个表达式来告诉我们 a 是否连接到边 b 和另一条边 c,对于所有边 c。我们可以通过扩展 slice、复制和转置并寻找两者之间的交集来映射所有组合 b,c

expanded_dim = [slice.shape[0],slice.shape[1],slice.shape[1]]    # value [n,e,e]

# two copies of slice, expanded on different dimensions
expanded_slice = slice.unsqueeze(1).expand(expanded_dim)         # dimension [n,e,e]
transpose_slice = slice.unsqueeze(2).expand(expanded_dim)        # dimension [n,e,e]

G = torch.bitwise_and(expanded_slice,transpose_slice).int()      # dimension [n,e,e]

G[i',j',k'] 现在等于 1,当且仅当节点 i' 通过边 j' 连接到某个其他节点,并且节点 i' 通过边 k' 连接到其他节点。如果 j' = k' 值是 1,只要该边的一个端点是 i'.

最后,我们重新排序维度以获得您想要的形式。

G = torch.permute(G,(1,2,0))           # dimension [e,e,n]