torch_geometry.nn.GATConv中的断言错误



我试图在torch_geometric中使用图注意力网络(GAT(模块,但使用以下代码继续运行到AssertionError: Static graphs not supported in 'GATConv'

class GraphConv_sum(nn.Module):
def __init__(self, in_ch, out_ch, num_layers, block, adj):
super(GraphConv_sum, self).__init__()
adj_coo = coo_matrix(adj) # convert the adjacency matrix to COO format for Pytorch Geometric
self.edge_index = torch.tensor([adj_coo.row, adj_coo.col], dtype=torch.long)
self.g_conv = nn.ModuleList()

self.act = nn.LeakyReLU()
for n in range(num_layers):
if n == 0:
self.g_conv.append(block(in_ch, 16))
elif n > 0 and n < num_layers - 1:
self.g_conv.append(block(16, 16))
else:
self.g_conv.append(block(16, out_ch))
def forward(self, x):
for layer in self.g_conv:
x = layer(x=x, edge_index=self.edge_index)
x = self.act(x)
print(x.shape)
return x[:, 0, :]

当我用GATConv替换block,然后是标准的训练循环时,就会发生这种错误(其他conv层,如GCNConvSAGEConv没有任何问题(。我检查了文档并确保输入的形状是正确的(其他conv层也是如此(。

在源代码中,forward方法中有assert x.dim() == 2, "Static graphs not supported in 'GATConv'"部分,但显然批处理维度将在前向传递中发挥作用,x.dim()将为3。具有批次维度的输入形状为[1024,6200]。然而,如果我手动将断言条件更改为x.dim() == 3,则仍然会引发相同的错误,就好像条件不满足一样。我对GAT只有一个高级的了解,所以可能有一些我缺少的东西。无论如何,我有几个问题从这个

  • 我方是否存在任何可能导致此错误的实现错误
  • 这个断言条件是为了什么?这种情况下的静态图是什么

如果有任何见解和帮助,我将不胜感激!!谢谢

事实证明,由于注意力权重计算,GATConv不支持多个特征矩阵和单个edge_index。更多信息:https://github.com/pyg-team/pytorch_geometric/issues/2844

相关内容

  • 没有找到相关文章

最新更新