如何在此代码中提取 20% 的递减损失?

How can I extract 20% descending loss in this code?

我有以下代码:

def model_fn():
keras_model = create_keras_model()
 return tff.learning.from_keras_model(
      keras_model,
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

在这段代码中,我想通过降序排列损失对应的20%的项目来求平均。

 #server select in the top20% clients
selected_clients_weights = clinet_select(client_weights)

如何提取客户排序的损失?

一个好的起点是查看目录 tensorflow_federated/python/examples/simple_fedavg/ 并了解联合平均是如何实现的。

要将此扩展为仅根据损失取前 20% 的平均值,需要两件事:

  1. 添加来自 client_update 函数的额外输出,在本例中为 loss 值。
  2. 替换tff.federated_mean aggregation with a call to tff.federated_collect. This will return a sequence. This could then be sorted (possibly by weight) and averaged inside a new tff.tf_computation decorated method that is applied to the result of tff.federated_collect with tff.federated_map.