我正在尝试使用to_networkx
将PyG图转换为NetworkX图
根据文档,除了Data对象之外,我还可以选择性地将node和edge属性作为str-iterables传递。
以下是按节点和边缘的属性列表,值转换为字符串:
Nodes: ['3.3375725746154785', '2.0086510181427',..., '1.5960148572921753', '3.621992349624634']
Edges: ['0.9940207804344958', '0.48573804411542043', ..., '0.7245483440145621', '0.24117984598949904']
当我只将Data对象传递给to_networkx
时,它运行良好。然而,当我也传递这些属性列表时,我会得到以下错误:
G[u][v][key] = values[key][i]
KeyError: '0.30194718370332896'
我看过源代码,但不知道它在做什么。有人能帮我解释一下我的属性列表有什么问题,以及我需要更改什么才能被接受吗。
我可以确定的是,这个错误是专门针对我的边缘属性的。如果我删除它们,我会得到以下与节点属性相关的类似错误:
feat_dict.update({key: values[key][i]})
KeyError: '0.0'
如何构造图并将其传递给to_networkx
:
n1 = np.repeat(np.array([0,1,2,3,4,5,6]),5)
n2 = np.array([0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4])
cat = np.stack((n1,n2), axis=0)
e = torch.tensor(cat, dtype=torch.long)
edge_index = e.t().clone().detach()
edge_attr = torch.tensor(np.random.rand(35,1))
x = torch.tensor([[0], [0], [0], [0], [0], [1], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index.t().contiguous(), edge_attr = edge_attr)
在传递node和edge属性之前,我会进行字符串转换,以符合str-iterable要求:
networkx_node_values = list(map(str, data.x.t()[0].tolist()))
networkx_edge_values = list(map(str, edge_attr.t()[0].tolist()))
networkX_graph = to_networkx(data, node_attrs = networkx_node_values, edge_attrs = networkx_edge_values)
您需要将属性的名称作为列表传递:
to_networkx(<PyTorchGeometricDataObject>, node_attrs=[<Name of Node Attribute 1>, <Name of Node Attributes 2>, ... ], edge_attr=[<Edge Attribute 1>, ...])
或者在上下文中,基于您给定的最小示例:
import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
n1 = np.repeat(np.array([0,1,2,3,4,5,6]),5)
n2 = np.array([0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4])
cat = np.stack((n1,n2), axis=0)
e = torch.tensor(cat, dtype=torch.long)
edge_index = e.t().clone().detach()
edge_attr = torch.tensor(np.random.rand(35,1))
x = torch.tensor([[0], [0], [0], [0], [0], [1], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index.t().contiguous(), edge_attr = edge_attr)
print(data)
# Data(edge_attr=[35, 1], edge_index=[2, 35], x=[7, 1])
networkX_graph = to_networkx(data, node_attrs=["x"], edge_attrs=["edge_attr"])
print(networkX_graph.nodes(data=True))
# [(0, {'x': 0.0}), (1, {'x': 0.0}),...
print(networkX_graph.edges(data=True))
# [(0, 0, {'edge_attr': 0.3412137594357493}), ...