如何在Tensorboard中可视化图神经网络的模型图



我正在尝试将我用来预测分子性质的图神经网络的计算图可视化。该模型是在PyTorch中创建的,并将DGL图作为输入。试图可视化模型的代码片段如下所示:

train_log_dir = f'logs/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/train'
train_summary_writer = tensorboardX.SummaryWriter(train_log_dir)
train_summary_writer.add_graph(model, [transformer(dataset[0][0]), transformer(dataset[0][0])])

我遇到以下错误,TensorBoardX无法可视化图模型,拒绝接受DGL图作为输入,只需要张量。我能把这个模型形象化吗?

RuntimeError: Tracer cannot infer type of (Graph(num_nodes=3, num_edges=4,
ndata_schemes={'x': Scheme(shape=(10,), dtype=torch.float32)}
edata_schemes={'w': Scheme(shape=(4,), dtype=torch.float32)}), Graph(num_nodes=3, num_edges=4,
ndata_schemes={'x': Scheme(shape=(10,), dtype=torch.float32)}
edata_schemes={'w': Scheme(shape=(4,), dtype=torch.float32)}))
:Only tensors and (possibly nested) tuples of tensors, lists, or dictsare supported as inputs or outputs of traced functions, but instead got value of type DGLHeteroGraph.
Process finished with exit code 1

我通常使用火炬库中的SummaryWriter。它的工作原理是这样的:

...
from torch.utils.tensorboard import SummaryWriter
...
# initializing your model
model = ...
dummy_input = ...
...
writer = SummaryWriter(f'logs/net')
writer.add_graph(model, dummy_input)

然后在终端运行python脚本后运行:

tensorboard --logdir logs

,然后它抛出链接类似localhost:6006,这将是您的可视化图形模型。更多信息:https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html

最新更新