TensorFlow Federated:如何调整联合数据集中的非独立同分布性?

TensorFlow Federated: How to tune non-IIDness in federated dataset?

我正在 TensorFlow Federated (TFF) 中测试一些算法。在这方面,我想在具有不同“级别”的数据异质性(即非独立同分布性)的同一联合数据集上测试和比较它们。

因此,我想知道是否有任何方法可以自动或半自动地控制和调整特定联合数据集中非独立同分布的“级别”,例如通过 TFF APIs 或传统的 TF API(可能在 Dataset utils 内部)。

更实际一点:比如TFF提供的EMNIST联邦数据集,有3383个客户端,每个客户端都有自己的手写字符。然而,这些本地数据集在本地示例的数量和表示的 classes 方面似乎相当平衡(所有 classes 或多或少都在本地表示)。 如果我想要一个联合数据集(例如,从 TFF 的 EMNIST 开始),即:

我应该如何在 TFF 框架内进行以准备具有这些特征的联合数据集?

我应该手工完成所有工作吗?或者你们中的一些人有一些自动化这个过程的建议吗?

另一个问题:在 Hsu 等人的这篇论文 "Measuring the Effects of Non-Identical Data Distribution for Federated Visual Classification" 中,他们利用 Dirichlet 分布来合成一群不相同的客户端,并且他们使用 浓度参数 来控制客户端之间的相同性。这似乎是一种易于调整的方法来生成具有不同异质性水平的数据集。关于如何在 TFF 框架内或仅在 TensorFlow (Python) 中考虑像 EMNIST 这样的简单数据集实施此策略(或类似策略)的任何建议也将非常有用。

非常感谢。

对于联合学习模拟,在实验驱动程序中 Python 中设置客户端数据集以实现所需的分布是非常合理的。在某些高层,TFF 处理建模数据位置(类型系统中的“放置”)和计算逻辑。 Re-mixing/generating 模拟数据集并不是该库的核心,尽管您已经找到了有用的库。通过操纵 tf.data.Dataset 然后将客户端数据集“推送”到 TFF 计算中直接在 python 中执行此操作似乎很简单。

标签非独立同分布

是的,tff.simulation.datasets.build_single_label_dataset 就是为了这个目的。

它需要一个 tf.data.Dataset 并基本上过滤掉所有与 label_keydesired_label 值不匹配的示例(假设数据集产生 dict 类似的结构) .

对于 EMNIST,要创建 所有 个数据集(无论用户如何),可以通过以下方式实现:

train_data, _ = tff.simulation.datasets.emnist.load_data()
ones = tff.simulation.datasets.build_single_label_dataset(
  train_data.create_tf_dataset_from_all_clients(),
  label_key='label', desired_label=1)
print(ones.element_spec)
>>> OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
print(next(iter(ones))['label'])
>>> tf.Tensor(1, shape=(), dtype=int32)

数据失衡

结合使用 tf.data.Dataset.repeat and tf.data.Dataset.take 可能会造成数据不平衡。

train_data, _ = tff.simulation.datasets.emnist.load_data()
datasets = [train_data.create_tf_dataset_for_client(id) for id in train_data.client_ids[:2]]
print([tf.data.experimental.cardinality(ds).numpy() for ds in datasets])
>>> [93, 109]
datasets[0] = datasets[0].repeat(5)
datasets[1] = datasets[1].take(5)
print([tf.data.experimental.cardinality(ds).numpy() for ds in datasets])
>>> [465, 5]