Stellargraph无法处理数据混洗



当我使用DGCNNs运行StellarGraph的图形分类演示时,我得到了与演示中相同的结果。

然而,当我测试第一次使用以下代码对数据进行混洗时会发生什么:

shuffler = list(zip(graphs, graph_labels))
random.shuffle(shuffler)
graphs, graph_labels = zip(*shuffler)

该模型根本没有学习(准确率约为50%,就像数据分布一样(。

有人知道为什么会发生这种事吗?也许我拖着脚走错了?还是说数据一开始就应该取消隐藏(也是为什么?这没有任何意义(?或者这是StellarGraph实现中的一个错误?

我发现了问题。这与混洗算法无关,也与StellarGraph的实现无关。问题出现在演示中,位于以下行:

train_gen = gen.flow(
list(train_graphs.index - 1),
targets=train_graphs.values,
batch_size=50,
symmetric_normalization=False,
)
test_gen = gen.flow(
list(test_graphs.index - 1),
targets=test_graphs.values,
batch_size=1,
symmetric_normalization=False,
)

问题是由train_graphs.index - 1test_graphs.index - 1引起的。索引已经在0n之间的范围内,因此从中减去一将导致图数据为"0";移位";一个向后,导致每个数据点获得不同数据点的标签。

要解决此问题,只需将它们更改为train_graphs.indextest_graphs.index,而不在末尾添加-1

最新更新