How to gather all client weights at server in TFF?

I am trying to implement custom aggregation in TFF by modifying the code from this tutorial. I want to rewrite next_fn so that all client weights end up on the server for further computation. Since federated_collect has been removed from tff-nightly, I am trying to do this with federated_aggregate.

Here is what I have so far:

def accumulate(x, y):
    x.append(y)
    return x


def merge(x, y):
    x.extend(y)
    return y


@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_state, federated_dataset):
    server_weights_at_client = tff.federated_broadcast(
        server_state.trainable_weights)
    client_deltas = tff.federated_map(
        client_update_fn, (federated_dataset, server_weights_at_client))

    z = []
    agg_result = tff.federated_aggregate(client_deltas, z,
                                         accumulate=tff.tf_computation(accumulate),
                                         merge=tff.tf_computation(merge),
                                         report=tff.tf_computation(lambda x: x))

    new_weights = do_smth_with_result(agg_result)
    server_state = tff.federated_map(
        server_update_fn, (server_state, new_weights))
    return server_state

However, this results in the following exception:

  File "/home/yana/Documents/Uni/Thesis/grufedatt_try.py", line 351, in <module>
    def next_fn(server_state, federated_dataset):
  File "/home/yana/anaconda3/envs/fedenv/lib/python3.9/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 494, in __call__
    wrapped_func = self._strategy(
  File "/home/yana/anaconda3/envs/fedenv/lib/python3.9/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 222, in __call__
    result = fn_to_wrap(*args, **kwargs)
  File "/home/yana/Documents/Uni/Thesis/grufedatt_try.py", line 358, in next_fn
    agg_result = tff.federated_aggregate(client_deltas, z,
  File "/home/yana/anaconda3/envs/fedenv/lib/python3.9/site-packages/tensorflow_federated/python/core/impl/federated_context/intrinsics.py", line 140, in federated_aggregate
    raise TypeError(
TypeError: Expected parameter `accumulate` to be of type (<<<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>>,<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>> -> <<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>>), but received (<<>,<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>> -> <<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>>) instead.

Try using tff.aggregators.federated_sample with max_num_samples equal to the number of clients you have.

This should be a straightforward drop-in replacement for how you previously used tff.federated_collect.
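
A rough sketch of what that drop-in might look like inside next_fn, assuming tff.aggregators.federated_sample takes the client-placed value plus a max_num_samples argument and returns the sampled updates at the server (NUM_CLIENTS is a placeholder for however many clients you train with, not a name from the original code):

# Rough sketch: keep every client's delta by sampling as many values as
# there are participating clients. NUM_CLIENTS is a placeholder.
agg_result = tff.aggregators.federated_sample(
    client_deltas, max_num_samples=NUM_CLIENTS)

new_weights = do_smth_with_result(agg_result)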


The problem in your accumulate is that you are changing the number of tensors the accumulator contains, so an error is raised when multiple accumulators are merged. If you still want to go this route, then for a rank-1 accumuland with k elements you could probably do something like the following instead:

@tff.tf_computation(tff.types.TensorType(tf.float32, [None, k]),
                    tff.types.TensorType(tf.float32, [k]))
def accumulate(accumulator, accumuland):
  return tf.concat([accumulator, tf.expand_dims(accumuland, axis=0)], axis=0)
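
For completeness, a minimal sketch of how that accumulate could be wired into federated_aggregate, under the same simplifying assumption of a single rank-1 accumuland with k elements; the zero value and the merge and report functions below are illustrative, not part of the original answer (assumes tensorflow as tf, tensorflow_federated as tff, and numpy as np are imported, and client_deltas comes from the question's next_fn):

# Illustrative only: the accumulator is a [None, k] tensor that gains one row
# per client, so the zero is an empty [0, k] array and merge stacks rows.
zero = np.zeros([0, k], dtype=np.float32)

@tff.tf_computation(tff.types.TensorType(tf.float32, [None, k]),
                    tff.types.TensorType(tf.float32, [None, k]))
def merge(x, y):
  return tf.concat([x, y], axis=0)

agg_result = tff.federated_aggregate(
    client_deltas, zero,
    accumulate=accumulate,
    merge=merge,
    report=tff.tf_computation(lambda x: x))

The resulting agg_result is then a server-placed [num_clients, k] tensor holding one row per client, which you can post-process however you like.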