我正在创建图模型(GCN(。由于我不习惯PyTorch,所以在设置数据集时遇到了问题。我想制作自定义的图形数据。
我不知道为什么我的代码不能正常工作:
class MoneyGraph(Dataset):
def __init__(self):
# node features
node_features = np.loadtxt(
"./datas/node_features.csv",
encoding="utf-8",
delimiter=",",
dtype=np.float32,
skiprows=1,
)
self.x = torch.tensor(node_features, dtype=torch.float)
# label
labels = [1, 1, 1, 0, 0, 1, 0, 0, 0, 1]
self.y = torch.tensor(labels, dtype=torch.float)
# edge index
target_nodes = [1, 2, 3, 5, 3, 5, 7, 1, 7]
source_nodes = [9, 2, 2, 1, 7, 7, 3, 4, 2]
self.edge_index = torch.tensor([source_nodes, target_nodes], dtype=torch.long)
# etc info
self.num_node_features = 5
self.num_classes = 2
self.num_nodes = len(self.x)
def __getitem__(self, index):
return self.x[index], self.y[index]
def __len__(self):
return len(self.x)
def is_undirected(self):
return False
我试着运行这个:
loader = DataLoader(dataset, batch_size=5, shuffle=True)
for batch in loader:
print(batch.x)
我收到以下错误消息:
属性错误:"list"对象没有属性"x"数据集的__getitem__
函数返回一个由两个元素组成的元组。为了访问它们,您需要执行batch[0]
和batch[1]
以分别获得self.x
和self.y
的元素。
或者,您可以直接从迭代器进行析构函数:
for x, y in loader:
print(x)
print(y)
其中x
和y
将分别具有五个元素。数据加载器为您处理批处理元素的析构函数。