如何在pytorch中生成图形数据



我正在创建图模型(GCN(。由于我不习惯PyTorch,所以在设置数据集时遇到了问题。我想制作自定义的图形数据。

我不知道为什么我的代码不能正常工作:

class MoneyGraph(Dataset):
def __init__(self):
# node features
node_features = np.loadtxt(
"./datas/node_features.csv",
encoding="utf-8",
delimiter=",",
dtype=np.float32,
skiprows=1,
)
self.x = torch.tensor(node_features, dtype=torch.float)
# label
labels = [1, 1, 1, 0, 0, 1, 0, 0, 0, 1]
self.y = torch.tensor(labels, dtype=torch.float)
# edge index
target_nodes = [1, 2, 3, 5, 3, 5, 7, 1, 7]
source_nodes = [9, 2, 2, 1, 7, 7, 3, 4, 2]
self.edge_index = torch.tensor([source_nodes, target_nodes], dtype=torch.long)
# etc info
self.num_node_features = 5
self.num_classes = 2
self.num_nodes = len(self.x)
def __getitem__(self, index):
return self.x[index], self.y[index]
def __len__(self):
return len(self.x)
def is_undirected(self):
return False

我试着运行这个:

loader = DataLoader(dataset, batch_size=5, shuffle=True)
for batch in loader:
print(batch.x)

我收到以下错误消息:

属性错误:"list"对象没有属性"x"

数据集的__getitem__函数返回一个由两个元素组成的元组。为了访问它们,您需要执行batch[0]batch[1]以分别获得self.xself.y的元素。

或者,您可以直接从迭代器进行析构函数:

for x, y in loader:
print(x)
print(y)

其中xy将分别具有五个元素。数据加载器为您处理批处理元素的析构函数。

最新更新