Stellargraph 无法处理数据混洗

Stellargraph failing to work with data shuffle

当我 运行 StellarGraph 的 demo 使用 DGCNN 进行图形分类时,我得到了与演示中相同的结果。

但是,当我使用以下代码测试第一次打乱数据时发生的情况:

shuffler = list(zip(graphs, graph_labels))
random.shuffle(shuffler)
graphs, graph_labels = zip(*shuffler)

模型根本没有学习(准确率在 50% 左右——就像数据分布一样)。

有谁知道为什么会这样?也许我以错误的方式洗牌?还是首先应该对数据进行整理(也是为什么?这没有任何意义)?还是 StellarGraph 实现中的错误?

我发现了问题。这与改组算法无关,也与 StellarGraph 的实现无关。问题出在演示中,在以下几行:

train_gen = gen.flow(
    list(train_graphs.index - 1),
    targets=train_graphs.values,
    batch_size=50,
    symmetric_normalization=False,
)

test_gen = gen.flow(
    list(test_graphs.index - 1),
    targets=test_graphs.values,
    batch_size=1,
    symmetric_normalization=False,
)

问题是由 train_graphs.index - 1test_graphs.index - 1 引起的。索引已经在 0n 之间的范围内,因此从中减去一个会导致图形数据向后“移动”一个,导致每个数据点获得不同数据点的标签.

要解决此问题,只需将它们更改为 train_graphs.indextest_graphs.index,最后不带 -1