我想试试这里演示的链接预测功能。以下是我的版本:
PyTorch Geometric v2.0.2
PyTorch v1.9.0+cu111
我很困惑为什么为每个张量打印cuda:0
,但当我通过RandomLinkSplit
传递数据时,我看到了错误。
import torch
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import negative_sampling
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = T.Compose([
T.NormalizeFeatures(),
T.ToDevice(device),
])
dataset = Planetoid(root='/tmp/Planetoid', name='Cora', transform=transform)
data = dataset[0]
print(data.to_dict())
print(data.keys)
transform = T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,)
train_data, val_data, test_data = transform(data)
输出:
{'x': tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], device='cuda:0'), 'edge_index': tensor([[ 0, 0, 0, ..., 2707, 2707, 2707],
[ 633, 1862, 2582, ..., 598, 1473, 2706]], device='cuda:0'), 'y': tensor([3, 4, 4, ..., 3, 3, 3], device='cuda:0'), 'train_mask': tensor([ True, True, True, ..., False, False, False], device='cuda:0'), 'val_mask': tensor([False, False, False, ..., False, False, False], device='cuda:0'), 'test_mask': tensor([False, False, False, ..., True, True, True], device='cuda:0')}
['val_mask', 'test_mask', 'edge_index', 'train_mask', 'x', 'y']
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_72/414574324.py in <module>
20
21 transform = T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,)
---> 22 train_data, val_data, test_data = transform(data)
/usr/local/lib/python3.7/dist-packages/torch_geometric/transforms/random_link_split.py in __call__(self, data)
204 train_edges,
205 neg_edge_index[:, num_neg_val + num_neg_test:],
--> 206 out=train_store,
207 )
208 self._create_label(
/usr/local/lib/python3.7/dist-packages/torch_geometric/transforms/random_link_split.py in _create_label(self, store, index, neg_edge_index, out)
284 if neg_edge_index.numel() > 0:
285 edge_label = torch.cat([edge_label, neg_edge_label], dim=0)
--> 286 edge_index = torch.cat([edge_index, neg_edge_index], dim=-1)
287 out[self.key] = edge_label
288 out[f'{self.key}_index'] = edge_index
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking arugment for argument tensors in method wrapper__cat)
这个问题确实是一个bug。感谢您的报道。