当 运行 "Building Your Own Federated Learning Algorithm" 教程时,Tensorflow federated (TFF) 0.19 的性能明显低于 TFF 0.17

Tensorflow federated (TFF) 0.19 performs significantly worse than TFF 0.17 when running "Building Your Own Federated Learning Algorithm" tutorial

在“建立你自己的联邦学习算法”教程的最后,它说,在训练我们的模型 15 轮之后,我们将期望 sparse_categorical_accuracy大约 0.25,但根据我的 运行s,运行按原样在 colab 中使用教程给出的结果在 0.09 和 0.11 之间。然而,只需将 tf 和 tff 版本分别更改为 2.3.x 和 0.17,就会得到大约 0.25 的结果,正如我们预期的那样!

要按原样复制运行 上述教程,它应该使用 tf 2.5 和 tff 0.19。之后 运行 相同的教程,只需更改

!pip install --quiet --upgrade tensorflow-federated

!pip install --quiet tensorflow==2.3.0
!pip install --quiet --upgrade tensorflow-federated==0.17.0

此外,tf 2.4 和 tff 0.18 组合工作得很好,得分约为 0.25。所以只有 tf 2.5 和 tff 0.19 的组合没有给出预期的结果。

需要说明的是,我并不是说第一次设置不会训练模型; 运行使用它进行 200 轮显示分数稳步提高,达到 0.7-0.8 左右。我将不胜感激澄清为什么会这样,或者如果我做错了什么请指出。

编辑: 为了确保在不同的 tff 版本中使用相同的客户端,我使用了以下代码

训练数据

sorted_client_ids = sorted(emnist_train.client_ids)
sorted_client_ids2 = sorted_client_ids[0:10]

federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
                       for x in sorted_client_ids2]

测试数据

sorted_client_ids = sorted(emnist_test.client_ids)
sorted_client_ids2 = sorted_client_ids[0:100]

def data(client, source=emnist_test):
    return preprocess(source.create_tf_dataset_for_client(client))

central_emnist_test = (tf.data.Dataset.from_tensor_slices(
    [data(client) for client in sorted_client_ids2])).flat_map(lambda x: x)

我每人训练了 50 轮。我使用这些设置得到的结果是

对于 tff 0.17:损失:1.8676 - sparse_categorical_accuracy:0.5115

对于 tff 0.18:损失:1.8503 - sparse_categorical_accuracy:0.5160

对于 tff 0.19:损失:2.2007 - sparse_categorical_accuracy:0.1014

所以我的问题是所有三个版本的 tff 都使用相同的训练数据、相同的测试数据、模型具有相同的初始化和相同轮次的训练,但 tff 0.19 和 tff 0.18/0.17 的结果大不相同,而 tff 0.18 和 0.17 产生了非常相似的结果。

再次澄清 tff 0.19 也提高了它的准确性,但程度要小得多。

编辑 2: 根据 Zachary Charles 的建议,我使用了联合 sgd。对于 tff 0.18 和 0.17,编辑第一行。

!pip install --quiet --upgrade tensorflow-federated
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()

import collections
import attr
import functools
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

print(tf.__version__)
print(tff.__version__)

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):
    def batch_format_fn(element):
        return(tf.reshape(element['pixels'],[-1,784]),
              tf.reshape(element['label'],[-1,1]))
    return dataset.batch(BATCH_SIZE).map(batch_format_fn)

sorted_client_ids = sorted(emnist_train.client_ids)
sorted_client_ids2 = sorted_client_ids[0:10]

federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
                       for x in sorted_client_ids2]

def create_keras_model():
    return tf.keras.models.Sequential([
        tf.keras.layers.InputLayer(input_shape=(784,)),
        tf.keras.layers.Dense(10, kernel_initializer='zeros'),
        tf.keras.layers.Softmax(),
    ])

def model_fn():
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=federated_train_data[0].element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
    
sorted_client_ids = sorted(emnist_test.client_ids)
sorted_client_ids2 = sorted_client_ids[0:10]

def data(client, source=emnist_test):
    return preprocess(source.create_tf_dataset_for_client(client))

central_emnist_test = (tf.data.Dataset.from_tensor_slices(
    [data(client) for client in sorted_client_ids2])).flat_map(lambda x: x)

def evaluate(server_state):
    keras_model = create_keras_model()
    keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  
    )
    keras_model.set_weights(server_state)
    keras_model.evaluate(central_emnist_test)


iterative_process = tff.learning.build_federated_sgd_process(
    model_fn,
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.01))

state = iterative_process.initialize()
evaluate(state.model.trainable)

for round in range(50):
    print(round)
    state,_ = iterative_process.next(state, federated_train_data)

evaluate(state.model.trainable)

我得到的结果是

训练前

训练后

TFF 0.19 将提供的数据集(包括本教程中使用的 EMNIST)从 HDF5 支持的实现移至 SQL 支持的实现 (commit)。这可能会改变客户端的顺序,这将改变教程中用于训练的客户端。

值得注意的是,在大多数模拟中,这应该不会改变任何东西。通常应在每一轮对客户进行随机抽样(出于说明的原因,本教程中未这样做)并且通常至少应完成 100 轮(如您所说)。

我将更新教程以通过对客户端 ID 进行排序,然后按顺序选择它们来保证可重复性。

对于任何感兴趣的人,更好的做法是 a) 对客户端 ID 进行排序,然后 b) 使用类似 np.random.RandomState 的方式进行采样,如以下代码段所示:

emnist_train, _ = tff.simulation.datasets.emnist.load_data()
random_state = np.random.RandomState(seed=1729)
sorted_client_ids = sorted(emnist_train.client_ids)
sampled_client_ids = random_state.choice(sorted_client_ids, size=NUM_CLIENTS)