翻转 TF 数据集的标签

Flipping the labels of a TF dataset

我想为 CIFAR-100 创建一个恶意数据集来测试类似于这个 EMNIST 恶意数据集的联合学习攻击:

url_malicious_dataset = 'https://storage.googleapis.com/tff-experiments-public/targeted_attack/emnist_malicious/emnist_target.mat'
filename = 'emnist_target.mat'
path = tf.keras.utils.get_file(filename, url_malicious_dataset)
emnist_target_data = io.loadmat(path)

我尝试了以下方法将提取的示例数据集中的标签 0 翻转为 4,但此方法不起作用:

cifar_train, cifar_test = tff.simulation.datasets.cifar100.load_data(cache_dir=None)
example_dataset = cifar_train.create_tf_dataset_for_client(cifar_train.client_ids[0])
for example in example_dataset:
  if example['label'].numpy() == 0:
    example['label'] = tf.constant(4,dtype=tf.int64)

知道如何通过正确翻转标签为 CIFAR-100 而不是 EMNIST 创建类似版本的恶意数据集吗?

一般来说,tf.data.Dataset 对象可以使用它们的 .map 方法进行修改。因此,例如,一个简单的标签翻转可以按如下方式完成:

def flip_label(example):
  return {'image': example['image'], 'label': 99-example['label']}

flipped_dataset = example_dataset.map(flip_label)

这会反转标签 0-99。您可以执行类似于将 0 发送到 4 并修复所有其他标签的操作。

请注意,如果您想将此应用到 cifar_train 中的所有客户端数据集,则必须使用 tff.simulation.datasets.ClientData.preprocess 方法。也就是说,您可以执行类似 cifar_train.preprocess(lambda x: x.map(flip_label)).

的操作