如何在 Federated Tensorflow 中绘制增量权重的直方图摘要?

How to plot Histogram summary for delta weight in Federated Tensorflow?

我正在分析我在 Tensorflow Federated 中使用 FedAvg 实施的方法。我需要为每个客户端传送到服务器的增量权重创建一个直方图。每个客户端分别调用 simulation/federated_avaraging.py,但问题是我无法在其中调用以下 API。 tf.summary.histogram()。任何帮助将不胜感激。

在TFF中,TensorFlow代表"local computation";因此,如果您需要一种方法来检查 客户端的内容,您将需要首先通过 TFF 聚合您想要的值,或者检查本机 [=28] 中的 returned 值=].

如果你想使用 TF 操作,我建议使用 tff.federated_collect 内部函数,"gather" 在服务器上你想要的所有值,然后 federated_map 一个 TF 函数,它采用这些值并生成您想要的可视化效果。

如果您更愿意在 Python 级别工作,这里有一个简单的选项(这是我会采用的方法):只需 return 来自您 tff.federated_computation 的客户培训结果;当您调用此计算时,这将具体化这些结果的 Python 列表,您可以根据需要将其可视化。这大致类似于:

@tff.federated_computation(...)
def train_one_round(...):
  ...
  trained_clients = run_training(...)
  new_model = update_global_model(trained_clients,...)
  return new_model, trained_clients

在此示例中,此函数将 return 一个元组,其第二个元素是一个 Python 列表,表示所有客户端的训练结果。