Pytorch:Node2Sec:TypeError:元组索引必须是整数或切片,而不是元组



我正在尝试从torch_geometry.nn库运行Node2Vec。作为参考,我以这个例子为例。

在运行train((函数时,我一直得到TypeError: tuple indices must be integers or slices, not tuple

我将torch version 1.6.0CUDA 10.1以及最新版本的torch-scattertorch-sparsetorch-clustertorch-spline-convtorch-geometric一起使用。

以下是详细的错误:

错误的第1部分

错误的第2部分

谢谢你的帮助。

错误是由于torch.ops.torch_cluster.random_walk返回的是元组而不是数组/张量。我通过将torch_geometric.nn.Node2Vec中的函数pos_sampleneg_sample替换为这些函数来修复它。

def pos_sample(self, batch):
batch = batch.repeat(self.walks_per_node)
rowptr, col, _ = self.adj.csr()
rw = random_walk(rowptr, col, batch, self.walk_length, self.p, self.q)
if not isinstance(rw, torch.Tensor):
rw = rw[0]
walks = []
num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
for j in range(num_walks_per_rw):
walks.append(rw[:, j:j + self.context_size])
return torch.cat(walks, dim=0)

def neg_sample(self, batch):
batch = batch.repeat(self.walks_per_node * self.num_negative_samples)
rw = torch.randint(self.adj.sparse_size(0),
(batch.size(0), self.walk_length))
rw = torch.cat([batch.view(-1, 1), rw], dim=-1)
walks = []
num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
for j in range(num_walks_per_rw):
walks.append(rw[:, j:j + self.context_size])
return torch.cat(walks, dim=0)

请参阅PyTorch Node2Verc文档。

最新更新