如何摆脱 placements(SERVER 或 CLIENTS)以便我可以将 float32@SERVER 转换为 float32?

How to get rid of placements(SERVER or CLIENTS) so that I can transform float32@SERVER to float32?

我正在尝试完成《构建您自己的联合学习算法》(Building Your Own Federated Learning Algorithm) 教程中的学习率衰减挑战。我使用了以下代码:

import nest_asyncio
nest_asyncio.apply()

import collections
import attr
import functools
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)  # make the client sampling below reproducible

# Federated EMNIST: one dataset shard per writer (client).
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

NUM_CLIENTS = 10  # clients sampled for training
BATCH_SIZE = 20
initial_lr = 0.01  # starting client learning rate
decay_rate = 0.0005  # subtracted from the lr once per round
minimum_lr = initial_lr/2  # floor: lr never decays below this

def preprocess(dataset):
    """Batch a client dataset and reshape examples to (pixels, label) pairs.

    Pixels are flattened to [-1, 784]; labels to [-1, 1].
    """
    def to_flat_pair(element):
        pixels = tf.reshape(element['pixels'], [-1, 784])
        label = tf.reshape(element['label'], [-1, 1])
        return (pixels, label)

    batched = dataset.batch(BATCH_SIZE)
    return batched.map(to_flat_pair)

# Sample a fixed set of clients (without replacement) and preprocess
# each one's local dataset for training.
client_ids = np.random.choice(emnist_train.client_ids,
                              size=NUM_CLIENTS, replace=False)

federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
                       for x in client_ids]

def create_keras_model():
    """A minimal softmax classifier over flattened 28x28 EMNIST images."""
    layers = [
        tf.keras.layers.InputLayer(input_shape=(784,)),
        tf.keras.layers.Dense(10, kernel_initializer='zeros'),
        tf.keras.layers.Softmax(),
    ]
    return tf.keras.models.Sequential(layers)

def model_fn():
    """Wrap a fresh Keras model for TFF.

    A new Keras model must be constructed on every call; TFF serializes
    this function and cannot capture a pre-built model.
    """
    return tff.learning.from_keras_model(
        create_keras_model(),
        input_spec=federated_train_data[0].element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
    """Run one local training pass on a client, starting from server weights.

    Returns the client's trainable variables after training.
    """
    client_weights = model.trainable_variables
    # Overwrite the local model with the current global weights.
    tf.nest.map_structure(lambda x,y: x.assign(y),
                         client_weights, server_weights)

    for batch in dataset:
        with tf.GradientTape() as tape:
            outputs = model.forward_pass(batch)
        grads = tape.gradient(outputs.loss, client_weights)
        # clip_by_global_norm returns (clipped_grads, global_norm); keep the grads.
        grads = tf.clip_by_global_norm(grads, 5.0)[0]
        grads_and_vars = zip(grads, client_weights)
        client_optimizer.apply_gradients(grads_and_vars)

    return client_weights

@tf.function
def server_update(model, mean_client_weights):
    """Replace the server model's weights with the averaged client weights."""
    model_weights = model.trainable_variables
    tf.nest.map_structure(lambda x,y: x.assign(y),
                         model_weights, mean_client_weights)

    return model_weights

@tff.tf_computation
def server_init():
    """Build the initial (unplaced) model weights for the server."""
    model = model_fn()
    return model.trainable_variables

@tff.federated_computation
def initialize_fn():
    # NOTE: returning a list of two separately placed values yields type
    # <weights@SERVER, float32@SERVER>, NOT the placed tuple
    # <weights, float32>@SERVER that the commented-out line would produce.
    return [tff.federated_value(server_init(), tff.SERVER), tff.federated_value(initial_lr, tff.SERVER)]
    #return tff.federated_value([server_init(),initial_lr], tff.SERVER)

# Derive TFF type signatures from a throwaway model so the tf_computations
# below can declare explicit input types.
whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
str(tf_dataset_type)

model_weights_type = server_init.type_signature.result
str(model_weights_type)

@tff.tf_computation(tf_dataset_type, model_weights_type,tf.float32)
def client_update_fn(tf_dataset, server_weights, LR):
    """TFF wrapper: one client's local training pass at learning rate LR."""
    model = model_fn()
    client_optimizer=tf.keras.optimizers.SGD(learning_rate=LR)
    return client_update(model, tf_dataset, server_weights, client_optimizer)

@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
    """TFF wrapper: install the averaged client weights on the server model."""
    model = model_fn()
    return server_update(model, mean_client_weights)

federated_server_type = tff.FederatedType(model_weights_type,
                                         tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type,
                                          tff.CLIENTS)
#federated_server_type_with_LR = tff.FederatedType([model_weights_type,tff.to_type((tf.float32))],tff.SERVER)
# A Python list of two SERVER-placed types, i.e. <weights@SERVER, float32@SERVER>.
# In TFF's type system this is NOT the same as <weights, float32>@SERVER;
# tff.federated_zip converts the former into the latter.
federated_server_type_with_LR = [tff.FederatedType(model_weights_type,tff.SERVER),
                                 tff.FederatedType(tff.to_type((tf.float32)),tff.SERVER)]

@tf.function
def decay_lr(lr):
    """Subtract decay_rate from lr, never dropping below minimum_lr."""
    decayed = lr - decay_rate
    # Equivalent to: decayed if decayed > minimum_lr else minimum_lr.
    return tf.maximum(decayed, minimum_lr)

@tff.tf_computation(tf.float32)
def decay_lr_fn(lr):
    """TFF wrapper around decay_lr, usable with tff.federated_map."""
    return decay_lr(lr)

@tff.federated_computation(federated_server_type_with_LR, federated_dataset_type)
def next_fn(server_weights_and_LR, federated_dataset):
    """One federated round: decay lr, broadcast, train on clients, average."""
    server_weights = server_weights_and_LR[0]
    #LR_SERVER = server_weights_and_LR[1]
    #LR_CLIENTS = tff.federated_broadcast(server_weights_and_LR[1])

    # Decay the learning rate on the server, then ship it to clients.
    LR = server_weights_and_LR[1]
    LR_NEW = tff.federated_map(decay_lr_fn, LR)
    LR_NEW_CLIENTS = tff.federated_broadcast(LR_NEW)

    # Broadcast the server weights to the clients
    server_weights_at_client = tff.federated_broadcast(server_weights)


    # Each client computes their updated weights
    client_weights = tff.federated_map(
        client_update_fn, (federated_dataset, server_weights_at_client, LR_NEW_CLIENTS))

    # The server averages are updated
    mean_client_weights = tff.federated_mean(client_weights)

    # The server update
    server_weights = tff.federated_map(server_update_fn, mean_client_weights)

    # Returns type <weights@SERVER, float32@SERVER>: each element carries its
    # own placement (this is the type mismatch the question is about).
    #return server_weights_and_LR
    return [server_weights, LR_NEW]

# Assemble the iterative process from the init/next computations.
federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn)

# Build a centralized test set from the first 100 test clients.
sorted_client_ids = sorted(emnist_test.client_ids)
sorted_client_ids2 = sorted_client_ids[0:100]

def data(client, source=emnist_test):
    # Preprocessed dataset for one test client.
    return preprocess(source.create_tf_dataset_for_client(client))
central_emnist_test = (tf.data.Dataset.from_tensor_slices(
    [data(client) for client in sorted_client_ids2])).flat_map(lambda x: x)

def evaluate(server_state):
    """Evaluate the given server weights on the centralized EMNIST test set."""
    eval_model = create_keras_model()
    eval_model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )
    eval_model.set_weights(server_state)
    eval_model.evaluate(central_emnist_test)

# server_state is [model weights @SERVER, learning rate @SERVER].
server_state = federated_algorithm.initialize()
evaluate(server_state[0])

# Named round_num rather than round so the builtin round() is not shadowed.
for round_num in range(15):
    print(round_num)
    server_state = federated_algorithm.next(server_state, federated_train_data)
    print(server_state[1])  # current (decayed) learning rate

evaluate(server_state[0])

此代码工作正常,但我想将学习率定义添加到 server_init() 函数。所以基本上有以下

@tff.tf_computation
def server_init():
    # Desired variant: bundle the learning rate with the weights here ...
    model = model_fn()
    return [model.trainable_variables, initial_lr]

@tff.federated_computation
def initialize_fn():
    # ... so one federated_value gives the placed tuple <weights, float32>@SERVER.
    return tff.federated_value(server_init(), tff.SERVER)

但是这样做会导致以下问题

The return type of `initialize_fn` must be assignable to the first input argument of `next_fn`, but:
`initialize_fn` returned type:
<<float32[784,10],float32[10]>,float32>@SERVER
and the first input argument of `next_fn` is:
<server_weights_and_LR=<<float32[784,10],float32[10]>@SERVER,float32@SERVER>,federated_dataset={<float32[?,784],int32[?,1]>*}@CLIENTS>

问题是 next_fn() 末尾的 `return [server_weights, LR_NEW]` 返回的类型是 `<<float32[784,10],float32[10]>@SERVER, float32@SERVER>`,因为 server_weights 和 LR_NEW 各自都已经带有 @SERVER placement。目前

@tff.tf_computation
def server_init():
    model = model_fn()
    return model.trainable_variables

@tff.federated_computation
def initialize_fn():
    # Two separately placed values: type <weights@SERVER, float32@SERVER>.
    return [tff.federated_value(server_init(), tff.SERVER), tff.federated_value(initial_lr, tff.SERVER)]

返回的也是 `<<float32[784,10],float32[10]>@SERVER, float32@SERVER>` 类型,所以两者是匹配的。

但正如我所说,我想更改该部分,因此我想删除 next_fn 中 server_weight 和 LR_NEW 的位置,并将位置应用于包含的列表两者。我该怎么做?

还有人有解决这个挑战的“更清洁”的解决方案吗?

编辑:

我只想澄清一下:initialize 的输出必须与 next 的第一个输入匹配,而 next 本身是"循环"的——所以不仅要求 initialize 的输出与 next 的输入匹配,还要求 next 的输出与它自己的第一个输入参数匹配。

The first return argument of `next_fn` must be assignable to its first input argument, but found
`next_fn` which returns type:
<<float32[784,10],float32[10]>@SERVER,float32@SERVER>
which does not match its first input argument:
<<float32[784,10],float32[10]>,float32>@SERVER

您的代码中的问题是在手动创建 federated_server_type_with_LR.

在 TFF 的类型系统中,`<A@SERVER, B@SERVER>` 不同于 `<A, B>@SERVER`。您可以使用 tff.federated_zip() 将前者转换为后者,从而把 placement 提升到类型的顶层。

两种解决方案:

(1)修改next_fn的装饰器为@tff.federated_computation(tff.federated_zip(federated_server_type_with_LR), federated_dataset_type)

(2) [首选,以避免此类问题] 不要手动创建类型,而是从 initialize_fn 读取它。装饰器将是 @tff.federated_computation(initialize_fn.type_signature.result, federated_dataset_type)