将 PyG 图转换为 NetworkX 图

Converting a PyG graph to a NetworkX graph

我正在尝试使用 to_networkx

将我的 PyG 图转换为 NetworkX 图

根据 docs 除了 Data 对象之外,我还可以选择将节点和边缘属性作为 str 可迭代对象传递。

以下是按节点和边的属性列表,值已转换为字符串:

Nodes:  ['3.3375725746154785', '2.0086510181427',..., '1.5960148572921753', '3.621992349624634']

Edges:  ['0.9940207804344958', '0.48573804411542043', ..., '0.7245483440145621', '0.24117984598949904']

to_networkx 当我只将数据对象传递给它时运行良好。但是,当我也传递这些属性列表时,出现以下错误:

G[u][v][key] = values[key][i]
KeyError: '0.30194718370332896'

我查看了源代码,但无法弄清楚它在做什么。有人可以帮助解释我的属性列表有什么问题以及我需要更改哪些内容才能被接受。

我可以看出这个错误是专门指我的边缘属性。如果删除它们,我会收到以下与节点属性相关的类似错误:

feat_dict.update({key: values[key][i]})
KeyError: '0.0'

我如何构建图表并将其传递给 to_networkx:

n1 = np.repeat(np.array([0,1,2,3,4,5,6]),5)
n2 = np.array([0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4])
cat = np.stack((n1,n2), axis=0)
e = torch.tensor(cat, dtype=torch.long)
edge_index = e.t().clone().detach()
edge_attr = torch.tensor(np.random.rand(35,1))

x = torch.tensor([[0], [0], [0], [0], [0], [1], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index.t().contiguous(), edge_attr = edge_attr)

在传递节点和边缘属性之前,我进行了字符串转换以符合 str 可迭代要求:

networkx_node_values = list(map(str, data.x.t()[0].tolist()))
networkx_edge_values = list(map(str, edge_attr.t()[0].tolist()))
    
networkX_graph = to_networkx(data, node_attrs = networkx_node_values, edge_attrs = networkx_edge_values)

您需要将属性名称作为列表传递:

to_networkx(<PyTorchGeometricDataObject>, node_attrs=[<Name of Node Attribute 1>, <Name of Node Attributes 2>, ... ], edge_attr=[<Edge Attribute 1>, ...])

或者在上下文中,根据您给出的最小示例:

import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx

n1 = np.repeat(np.array([0,1,2,3,4,5,6]),5)
n2 = np.array([0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4])
cat = np.stack((n1,n2), axis=0)
e = torch.tensor(cat, dtype=torch.long)
edge_index = e.t().clone().detach()
edge_attr = torch.tensor(np.random.rand(35,1))

x = torch.tensor([[0], [0], [0], [0], [0], [1], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index.t().contiguous(), edge_attr = edge_attr)
print(data)
# Data(edge_attr=[35, 1], edge_index=[2, 35], x=[7, 1])

networkX_graph = to_networkx(data, node_attrs=["x"], edge_attrs=["edge_attr"])

print(networkX_graph.nodes(data=True))
# [(0, {'x': 0.0}), (1, {'x': 0.0}),...
print(networkX_graph.edges(data=True))
# [(0, 0, {'edge_attr': 0.3412137594357493}), ...