Access and modify weights sent from clients on the server in TensorFlow Federated

I'm using TensorFlow Federated, but I'm running into problems when I try to perform some operations on the server after it has read the client updates.

This is the function:

@tff.federated_computation(federated_server_state_type,
                           federated_dataset_type)
def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.
    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.
    Returns:
      A tuple of updated `ServerState` and `tf.Tensor` of average loss.
    """
    tf.print("run_one_round")
    server_message = tff.federated_map(server_message_fn, server_state)
    server_message_at_client = tff.federated_broadcast(server_message)

    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, server_message_at_client))

    weight_denom = client_outputs.client_weight


    tf.print(client_outputs.weights_delta)
    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=weight_denom)

    server_state = tff.federated_map(server_update_fn, (server_state, round_model_delta))
    round_loss_metric = tff.federated_mean(client_outputs.model_output, weight=weight_denom)

    return server_state, round_loss_metric, client_outputs.weights_delta.comp

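I want to print client_outputs.weights_delta and perform some operations on the weights that the clients send to the server before tff.federated_mean is applied, but I don't know how to do it.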

When I try to print it, I get this:

Call(Intrinsic('federated_map', FunctionType(StructType([FunctionType(StructType([('weights_delta', StructType([TensorType(tf.float32, [5, 5, 1, 32]), TensorType(tf.float32, [32]), ....]) as ClientOutput, PlacementLiteral('clients'), False)))]))

Is there any way to modify these elements?

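I tried returning client_outputs.weights_delta.comp and making the modification in the main loop (which I can do), and then calling a new method to carry out the rest of the server update, but I get this error: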

AttributeError: 'IterativeProcess' object has no attribute 'calculate_federated_mean'

where calculate_federated_mean is the name of the new function I created.

This is the main loop:

 for round_num in range(FLAGS.total_rounds):
        print("--------------------------------------------------------")
        sampled_clients = np.random.choice(train_data.client_ids, size=FLAGS.train_clients_per_round, replace=False)
        sampled_train_data = [train_data.create_tf_dataset_for_client(client) for client in sampled_clients]

        server_state, train_metrics, value_comp = iterative_process.next(server_state, sampled_train_data)

        print(f'Round {round_num}')
        print(f'\tTraining loss: {train_metrics:.4f}')
        if round_num % FLAGS.rounds_per_eval == 0:
            server_state.model_weights.assign_weights_to(keras_model)
            accuracy = evaluate(keras_model, test_data)
            print(f'\tValidation accuracy: {accuracy * 100.0:.2f}%')
            tf.print(tf.compat.v2.summary.scalar("Accuracy", accuracy * 100.0, step=round_num))

The project is based on simple_fedavg from GitHub: [TensorFlow Federated simple_fedavg][1].

Edit 1:

Thanks to @Jakub Konecny I have made some progress, but I've run into a new problem that I don't really understand.

If I use this client_update:

@tf.function
def client_update(model, dataset, server_message, client_optimizer):
    """Performs client local training of `model` on `dataset`.
    Args:
      model: A `tff.learning.Model`.
      dataset: A 'tf.data.Dataset'.
      server_message: A `BroadcastMessage` from server.
      client_optimizer: A `tf.keras.optimizers.Optimizer`.
    Returns:
      A `ClientOutput`.
    """
    model_weights = model.weights
    initial_weights = server_message.model_weights
    tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                          initial_weights)

    num_examples = tf.constant(0, dtype=tf.int32)
    loss_sum = tf.constant(0, dtype=tf.float32)
    # Explicit use `iter` for dataset is a trick that makes TFF more robust in
    # GPU simulation and slightly more performant in the unconventional usage
    # of large number of small datasets.
    for batch in iter(dataset):
        with tf.GradientTape() as tape:
            outputs = model.forward_pass(batch)
        grads = tape.gradient(outputs.loss, model_weights.trainable)
        client_optimizer.apply_gradients(zip(grads, model_weights.trainable))
        batch_size = tf.shape(batch['x'])[0]
        num_examples += batch_size
        loss_sum += outputs.loss * tf.cast(batch_size, tf.float32)

    weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                          model_weights.trainable,
                                          initial_weights.trainable)


    client_weight = tf.cast(num_examples, tf.float32)

    import sparse_ternary_compression
    sparsification_rate = 1
    testing_new = []
    # TODO: do not apply this to the biases
    for tensor in weights_delta:
        testing_new.append(sparse_ternary_compression.stc_compression(tensor, sparsification_rate))

    return ClientOutput(weights_delta, client_weight, loss_sum / client_weight, testing_new)

with these functions:

@tff.tf_computation
def stc_compression(original_tensor, sparsification_percentage):
    original_shape = tf.shape(original_tensor)
    tensor = tf.reshape(original_tensor, [-1])
    sparsification_percentage = tf.cast(sparsification_percentage, tf.float64)
    sparsification_rate = tf.size(tensor) / 100 * sparsification_percentage
    sparsification_rate = tf.cast(sparsification_rate, tf.int32)
    new_shape = tensor.get_shape().as_list()
    if sparsification_rate == 0:
        sparsification_rate = 1
    mask = tf.cast(tf.abs(tensor) >= tf.math.top_k(tf.abs(tensor), sparsification_rate)[0][-1], tf.float32)
    inv_mask = tf.cast(tf.abs(tensor) < tf.math.top_k(tf.abs(tensor), sparsification_rate)[0][-1], tf.float32)
    tensor_masked = tf.multiply(tensor, mask)
    sparsification_rate = tf.cast(sparsification_rate, tf.float32)
    average = tf.reduce_sum(tf.abs(tensor_masked)) / sparsification_rate
    compressed_tensor = tf.add(tf.multiply(average, mask) * tf.sign(tensor), tf.multiply(tensor_masked, inv_mask))
    negatives = tf.where(compressed_tensor < 0)
    positives = tf.where(compressed_tensor > 0)
    return negatives, positives, average, original_shape, new_shape

@tff.tf_computation
def stc_decompression(negatives, positives, average, original_shape, new_shape):
    decompressed_tensor = tf.zeros(new_shape, tf.float32)
    average_values_negative = tf.fill([tf.shape(negatives)[0], ], -average)
    average_values_positive = tf.fill([tf.shape(positives)[0], ], average)
    decompressed_tensor = tf.tensor_scatter_nd_update(decompressed_tensor, negatives, average_values_negative)
    decompressed_tensor = tf.tensor_scatter_nd_update(decompressed_tensor, positives, average_values_positive)
    decompressed_tensor = tf.reshape(decompressed_tensor, original_shape)
    return decompressed_tensor


@tff.tf_computation
def testing_new_list(list):
    testing = []
    for index in list:
        testing.append(
            stc_decompression(index[0], index[1],
                              index[2], index[3],
                              index[4]))

    return testing

called like this in the run_one_round function:

    @tff.federated_computation(federated_server_state_type,
                               federated_dataset_type)
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.
        Args:
          server_state: A `ServerState`.
          federated_dataset: A federated `tf.data.Dataset` with placement
            `tff.CLIENTS`.
        Returns:
          A tuple of updated `ServerState` and `tf.Tensor` of average loss.
        """
        server_message = tff.federated_map(server_message_fn, server_state)
        server_message_at_client = tff.federated_broadcast(server_message)

        client_outputs = tff.federated_map(
            client_update_fn, (federated_dataset, server_message_at_client))

        weight_denom = client_outputs.client_weight

        import sparse_ternary_compression
        testing = tff.federated_map(sparse_ternary_compression.testing_new_list, client_outputs.test)

        # round_model_delta holds the weights used by server_update, so this is what needs to change.
        round_model_delta = tff.federated_mean(
            client_outputs.weights_delta, weight=weight_denom)

        server_state = tff.federated_map(server_update_fn, (server_state, round_model_delta))
        round_loss_metric = tff.federated_mean(client_outputs.model_output, weight=weight_denom)

        return server_state, round_loss_metric, testing

But I get this exception:

Traceback (most recent call last):
  File "/mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/main.py", line 214, in <module>
    app.run(main)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/main.py", line 171, in main
    iterative_process = simple_fedavg_tff.build_federated_averaging_process(
  File "/mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/simple_fedavg_tff.py", line 95, in build_federated_averaging_process
    def client_update_fn(tf_dataset, server_message):
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 478, in __call__
    wrapped_func = self._strategy(
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 216, in __call__
    result = fn_to_wrap(*args, **kwargs)
  File "/mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/simple_fedavg_tff.py", line 98, in client_update_fn
    return client_update(model, tf_dataset, server_message, client_optimizer)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 889, in __call__
    result = self._call(*args, **kwds)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 933, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 763, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3050, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3444, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3279, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 999, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 672, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 986, in wrapper
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.autograph.pyct.error_utils.KeyError: in user code:

        /mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/simple_fedavg_tf.py:222 client_update  *
            testing_new.append(sparse_ternary_compression.stc_compression(tensor, sparsification_rate))
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/computation/function_utils.py:608 __call__  *
            return concrete_fn(packed_arg)
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/computation/function_utils.py:525 __call__  *
            return context.invoke(self, arg)
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/tensorflow_context/tensorflow_computation_context.py:54 invoke  *
            init_op, result = (
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/utils/tensorflow_utils.py:1097 deserialize_and_call_tf_computation  *
            input_map = {
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:3931 get_tensor_by_name  **
            return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:3755 as_graph_element
            return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:3795 _as_graph_element_locked
            raise KeyError("The name %s refers to a Tensor which does not "
    
        KeyError: "The name 'sub:0' refers to a Tensor which does not exist. The operation, 'sub', does not exist in the graph."
    
    
    Process finished with exit code 1

Edit 2:

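I solved the problem above by changing the decorator on stc_compression and stc_decompression from tff.tf_computation to tf.function. It now seems to work: if I print the variable testing, obtained from the `return server_state, round_loss_metric, testing` inside run_one_round, I get the weights I wanted from the beginning.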
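As a sanity check for the two functions, here is a minimal NumPy sketch of the same compress/decompress round trip. It is not the TFF code above, only a framework-free reference under a few assumptions: the names stc_compress and stc_decompress are hypothetical, it keeps exactly k entries (ties broken by index order) rather than using the threshold comparison in stc_compression, and it drops the separate new_shape value.

```python
import numpy as np


def stc_compress(tensor, sparsification_percentage):
    """Keep the top-k entries by magnitude; encode them as +/- their mean magnitude."""
    flat = np.asarray(tensor, dtype=np.float32).reshape(-1)
    # Number of entries to keep; at least one, like the `== 0` guard in the TF version.
    k = max(1, int(flat.size * sparsification_percentage / 100))
    # Indices of the k largest-magnitude entries (stable sort, so ties go by index).
    top_idx = np.argsort(-np.abs(flat), kind="stable")[:k]
    average = float(np.mean(np.abs(flat[top_idx])))
    negatives = top_idx[flat[top_idx] < 0]
    positives = top_idx[flat[top_idx] > 0]
    return negatives, positives, average, np.shape(tensor)


def stc_decompress(negatives, positives, average, original_shape):
    """Rebuild a dense tensor holding -average, 0 or +average in each position."""
    flat = np.zeros(int(np.prod(original_shape)), dtype=np.float32)
    flat[negatives] = -average
    flat[positives] = average
    return flat.reshape(original_shape)


delta = np.array([[0.5, -0.1], [-2.0, 0.05]], dtype=np.float32)
packed = stc_compress(delta, sparsification_percentage=50)
restored = stc_decompress(*packed)
print(restored)
```

With 50% sparsification the two largest-magnitude entries (0.5 and -2.0) survive and are replaced by +1.25 and -1.25, their mean magnitude; everything else becomes zero.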

I think what I just wrote for another question applies here as well.

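When you print client_outputs.weights_delta, you get an abstract representation of the result of another computation, which is mostly an internal implementation detail of TFF.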

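Write a tff.tf_computation-decorated method, using TensorFlow code, that makes the modification you need, then invoke it with the tff.federated_map operator at the place where you were trying to print the value.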
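To picture what this pattern computes, here is a minimal NumPy sketch that emulates it outside TFF. It is an illustration under assumptions, not TFF code: modify_delta (a simple clipping step) and weighted_mean are hypothetical helpers, and the list comprehension and weighted_mean only stand in for what tff.federated_map and the weighted tff.federated_mean do with the per-client values.

```python
import numpy as np


def modify_delta(weights_delta):
    """Per-client edit; hypothetical example that clips every tensor to [-1, 1]."""
    return [np.clip(t, -1.0, 1.0) for t in weights_delta]


def weighted_mean(client_deltas, client_weights):
    """Weighted average across clients, computed tensor by tensor."""
    total = float(sum(client_weights))
    num_tensors = len(client_deltas[0])
    return [
        sum(w * delta[i] for delta, w in zip(client_deltas, client_weights)) / total
        for i in range(num_tensors)
    ]


# Two clients, each sending a list with one kernel tensor and one bias tensor.
client_deltas = [
    [np.array([[2.0, -0.5]]), np.array([0.25])],
    [np.array([[0.0, -3.0]]), np.array([0.75])],
]
client_weights = [1.0, 3.0]  # e.g. number of training examples per client

# Stand-in for tff.federated_map(modify_fn, client_outputs.weights_delta):
# the same function is applied to each client's value independently.
modified = [modify_delta(d) for d in client_deltas]

# Stand-in for tff.federated_mean(modified, weight=weight_denom).
round_model_delta = weighted_mean(modified, client_weights)
print(round_model_delta)
```

In the real process, the body of modify_delta would be written with TensorFlow ops inside the tff.tf_computation-decorated function, and the mapped value would be passed to tff.federated_mean in run_one_round in place of client_outputs.weights_delta.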