Change of the dataset type in the execution stack

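The problem is that the dataset changes from one type to another at different points in the execution stack. For example, if I add a new dataset class with additional member attributes of interest (inheriting from one of the classes in ops.data.dataset_ops, such as UnaryDataset), then at a later point in execution (the client_update function) the dataset has been converted to the _VariantDataset type, so any added attributes are lost. So the question is how to preserve the member attributes of a newly defined dataset class during execution. Below is the emnist example, where the type changes from ParallelMapDataset to _VariantDataset.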

In training_utils.py, at line 194, I modified the function client_datasets to print the type of the dataset, as follows:

  def client_datasets(round_num):
    sampled_clients = sample_clients_fn(round_num)
    sampled_client_datasets = []
    for client in sampled_clients:
        # Build the client dataset once and reuse it for both append and print.
        d = train_dataset.create_tf_dataset_for_client(client)
        sampled_client_datasets.append(d)
        tf.print('CLIENT DATASETS: ', d, type(d))
    return sampled_client_datasets

The output is:

CLIENT DATASETS:  <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS:  <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS:  <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS:  <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS:  <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS:  <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS:  <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS:  <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS:  <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS:  <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>

Then, in the tf.function client_update in fed_avg_schedule.py (line 178), which is called on the clients, the dataset has a different type:

  @tf.function
  def client_update(model,
                    dataset,
                    initial_weights,
                    client_optimizer,
                    client_weight_fn=None):
    """Updates client model.

    Args:
      model: A `tff.learning.Model`.
      dataset: A `tf.data.Dataset`.
      initial_weights: A `tff.learning.Model.weights` from server.
      client_optimizer: A `tf.keras.optimizers.Optimizer` object.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor that provides the
        weight in the federated average of model deltas. If not provided, the
        default is the total number of examples processed on device.

    Returns:
      A `ClientOutput`.
    """

    tf.print('CLIENT UPDATE: ', dataset, type(dataset))
    ....

The output is:

CLIENT UPDATE:  <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE:  <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE:  <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE:  <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE:  <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE:  <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE:  <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE:  <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE:  <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE:  <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>

I may be wrong, but I did some tracing and found that at some point the function _to_components(self, value) of DatasetSpec is called to do the conversion:

  def _to_components(self, value):
    return value._variant_tensor  # pylint: disable=protected-access
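
To make the effect concrete, here is a toy model in plain Python (not TensorFlow code). The class and function names are hypothetical stand-ins: `to_components` imitates the idea of keeping only the variant handle, and `from_components` imitates a receiver that only knows the base class. It is a simplified sketch of why Python-only attributes do not survive such a round trip.

```python
class Dataset:
    """Toy stand-in for a dataset: all real state lives in one opaque handle."""

    def __init__(self, variant):
        self._variant = variant  # plays the role of _variant_tensor


class TaggedDataset(Dataset):
    """Hypothetical subclass that adds a Python-only attribute."""

    def __init__(self, variant, client_id):
        super().__init__(variant)
        self.client_id = client_id


def to_components(value):
    # Keep only the handle; everything else on the Python object is dropped.
    return value._variant


def from_components(variant):
    # The receiving side only knows the base class, so it rebuilds that.
    return Dataset(variant)


original = TaggedDataset(variant=[1, 2, 3], client_id="client_0")
restored = from_components(to_components(original))

print(type(original).__name__, hasattr(original, "client_id"))  # TaggedDataset True
print(type(restored).__name__, hasattr(restored, "client_id"))  # Dataset False
```

The data itself comes through intact; only the subclass identity and its extra attributes are gone.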

Edit - following the suggested answer

Below are the changes I made to the simple_fedavg example after pulling the latest version of the federated repository.

First, I added/modified the lines below in build_fed_avg_process of simple_fedavg_tff.py:

  server_message_type = server_message_fn.type_signature.result
  tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
  meta_data_type = tff.SequenceType(tf.string)

  @tff.tf_computation(tf_dataset_type, meta_data_type, server_message_type)
  def client_update_fn(tf_dataset, meta_data, server_message):
    model = model_fn()
    client_optimizer = client_optimizer_fn()
    return client_update(model, tf_dataset, meta_data, server_message, client_optimizer)

  @tff.tf_computation((tf_dataset_type, meta_data_type))
  def extract_data_metadata_fn(tf_dataset_metadata_tuple):
    x, y = tf_dataset_metadata_tuple
    return x, y

  federated_server_state_type = tff.FederatedType(server_state_type, tff.SERVER)
  federated_dataset_type = tff.FederatedType(
      (tf_dataset_type, meta_data_type), tff.CLIENTS)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and `tf.Tensor` of average loss.
    """
    server_message = tff.federated_map(server_message_fn, server_state)
    server_message_at_client = tff.federated_broadcast(server_message)

    data_set, meta_data = tff.federated_map(extract_data_metadata_fn, federated_dataset)

    #client_outputs = tff.federated_map(client_update_fn, (federated_dataset, server_message_at_client))
    client_outputs = tff.federated_map(client_update_fn, (data_set, meta_data, server_message_at_client))

In simple_fedavg_tf.py, I added a meta_data argument and the following print line:

@tf.function
def client_update(model, dataset, meta_data, server_message, client_optimizer):
  """Performs client local training of `model` on `dataset`.

  Args:
    model: A `tff.learning.Model`.
    dataset: A `tf.data.Dataset`.
    meta_data: Per-client metadata passed alongside `dataset`.
    server_message: A `BroadcastMessage` from server.
    client_optimizer: A `tf.keras.optimizers.Optimizer`.

  Returns:
    A `ClientOutput`.
  """
  tf.print(meta_data)

  model_weights = model.weights
  initial_weights = server_message.model_weights
  client_ids = server_message.client_ids
  tff.utils.assign(model_weights, initial_weights)

In the main file emnist_simple_fedavg.py, I modified the following lines of the main training loop in the main function:

meta_data = ['a','b','c','d']
server_state, train_metrics = iterative_process.next(server_state, (sampled_train_data, meta_data))

It did not work; I get the following error:

  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.py", line 176, in <module>
    app.run(main)
  File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.py", line 166, in main
    server_state, train_metrics = iterative_process.next(server_state, (sampled_train_data, sampled_clients.tolist()))
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/utils/function_utils.py", line 563, in __call__
    return context.invoke(self, arg)
  File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 49, in wrapped_f
    return Retrying(*dargs, **dkw).call(f, *args, **kw)
  File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 206, in call
    return attempt.get(self._wrap_exception)
  File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 247, in get
    six.reraise(self.value[0], self.value[1], self.value[2])
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/six/__init__.py", line 693, in reraise
    raise value
  File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 200, in call
    attempt = Attempt(fn(*args, **kwargs), attempt_number, False)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/execution_context.py", line 215, in invoke
    _ingest(executor, unwrapped_arg, arg.type_signature)))
  File "/usr/lib/python3.6/asyncio/base_events.py", line 484, in run_until_complete
    return future.result()
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 388, in _wrapped
    return await coro
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/execution_context.py", line 99, in _ingest
    ingested = await asyncio.gather(*ingested)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/execution_context.py", line 104, in _ingest
    return await executor.create_value(val, type_spec)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
    result = await fn(*fn_args, **fn_kwargs)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/reference_resolving_executor.py", line 289, in create_value
    value, type_spec))
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/caching_executor.py", line 245, in create_value
    await cached_value.target_future
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
    result = await fn(*fn_args, **fn_kwargs)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 111, in create_value
    self._target_executor.create_value(value, type_spec))
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 105, in _delegate
    result_value = await _delegate_with_trace_ctx(coro, self._event_loop)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 388, in _wrapped
    return await coro
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
    result = await fn(*fn_args, **fn_kwargs)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/federating_executor.py", line 383, in create_value
    return await self._strategy.compute_federated_value(value, type_spec)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/federated_resolving_strategy.py", line 275, in compute_federated_value
    for v, c in zip(value, children)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
    result = await fn(*fn_args, **fn_kwargs)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/reference_resolving_executor.py", line 282, in create_value
    *[self.create_value(val, t) for (_, val), t in zip(v_el, type_spec)])
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
    result = await fn(*fn_args, **fn_kwargs)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/reference_resolving_executor.py", line 289, in create_value
    value, type_spec))
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/caching_executor.py", line 245, in create_value
    await cached_value.target_future
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
    result = await fn(*fn_args, **fn_kwargs)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 111, in create_value
    self._target_executor.create_value(value, type_spec))
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 105, in _delegate
    result_value = await _delegate_with_trace_ctx(coro, self._event_loop)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 388, in _wrapped
    return await coro
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
    result = await fn(*fn_args, **fn_kwargs)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/eager_tf_executor.py", line 464, in create_value
    return EagerValue(value, self._tf_function_cache, type_spec, self._device)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/eager_tf_executor.py", line 367, in __init__
    type_spec, device)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/eager_tf_executor.py", line 335, in to_representation_for_type
    type_conversions.TF_DATASET_REPRESENTATION_TYPES)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/py_typecheck.py", line 41, in check_type
    type_string(type_spec), type_string(type(target))))
TypeError: Expected tensorflow.python.data.ops.dataset_ops.DatasetV2 or tensorflow.python.data.ops.dataset_ops.DatasetV1, found str.
E0721 23:53:29.388700 139706363909952 base_events.py:1285] Task was destroyed but it is pending!
task: <Task pending coro=<trace.<locals>.async_trace() running at /root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py:200> wait_for=<Future pending cb=[_chain_future.<locals>._call_check_cancel() at /usr/lib/python3.6/asyncio/futures.py:403, <TaskWakeupMethWrapper object at 0x7f0f7c07eca8>()]> cb=[<TaskWakeupMethWrapper object at 0x7f0f7c07e648>()]>

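The new dataset Python class would need to support serialization. This is necessary because TensorFlow Federated is designed to run on machines that are not necessarily the same as the machine on which the computation was written (e.g. smartphones in the case of cross-device federated learning). These machines may not be running Python and therefore would not understand the newly created subclass, so the serialization layer would need to be updated. However, this is pretty low-level, and there may be other ways to accomplish the desired goal.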

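Going up a level: if the goal is to provide metadata along with the dataset for each client, it may be easier to change the function signature of the iterative process returned by fed_avg_schedule.build_fed_avg_process so that it accepts a tuple of (dataset, metadata structure) for each client.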

Currently the signature of the next computation is (in the TFF type shorthand introduced in Custom Federated Algorithms, Part 1: Introduction to the Federated Core):

(<ServerState@SERVER, Datasets@CLIENTS> -> <ServerState@SERVER, Metrics@SERVER>)

(Here ServerState is the server state structure; Datasets and Metrics are determined by the model and the dataset.)

Instead, we probably want a signature that looks like this:

(<ServerState@SERVER, <Datasets, Metadata>@CLIENTS> -> <ServerState@SERVER, Metrics@SERVER>)

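To do this, we can do the following:

  1. Update the type of the run_one_round parameter to a tuple of tf_dataset_type and the metadata structure.
  2. Pass this new structure through the tff.federated_map call.
  3. Plumb the new parameter through.
  4. Add a new parameter to client_update_fn.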
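
The steps above can be pictured with a plain-Python stand-in (this is not TFF code). The names `client_update` and `run_one_round` are reused only for illustration, their bodies are hypothetical, and the loop plays the role that `tff.federated_map` plays in the real computation. The point is the data shape: each client contributes one (dataset, metadata) pair, and the metadata is handed to the per-client function together with the dataset.

```python
def client_update(dataset, metadata, server_message):
    """Hypothetical per-client work: here it just reports what it received."""
    return {
        "client": metadata,
        "num_examples": len(dataset),
        "round": server_message,
    }


def run_one_round(server_message, client_inputs):
    """client_inputs holds one (dataset, metadata) pair per client."""
    outputs = []
    for dataset, metadata in client_inputs:  # unpack each client's pair
        # The metadata travels alongside the dataset into the client function.
        outputs.append(client_update(dataset, metadata, server_message))
    return outputs


# Lists of numbers stand in for real per-client datasets.
client_inputs = [([1, 2, 3], "a"), ([4, 5], "b")]
round_outputs = run_one_round(server_message=0, client_inputs=client_inputs)
for out in round_outputs:
    print(out)
```

Because the metadata is an ordinary argument rather than an attribute on the dataset object, nothing depends on the dataset's Python class surviving serialization.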

Based on the updated information and the error log, I think the problem is in this part:
iterative_process.next(server_state, (sampled_train_data, meta_data))

I guess what you need is for the second argument of next to be an iterable of (sampled_train_data_element, meta_data_element) tuples, with one element per sampled client.

So this could be achieved by changing it to
iterative_process.next(server_state, zip(sampled_train_data, meta_data))
or, if that does not work, maybe this?
iterative_process.next(server_state, list(zip(sampled_train_data, meta_data)))
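
As a small plain-Python illustration of the reshaping involved (no TFF required): the original call passed one 2-tuple of (all datasets, all metadata), whereas the suggestion is one (dataset, metadata) tuple per client. The helper `pair_clients` below is hypothetical and the strings are placeholders for real datasets. It also checks lengths, because `zip` silently truncates to the shorter input, which matters if there are four metadata entries but a different number of sampled clients.

```python
def pair_clients(datasets, metadata):
    """Return one (dataset, metadata) tuple per client, checking lengths first."""
    datasets = list(datasets)
    metadata = list(metadata)
    if len(datasets) != len(metadata):
        # zip() would silently truncate to the shorter input.
        raise ValueError(
            f"{len(datasets)} datasets but {len(metadata)} metadata entries")
    return list(zip(datasets, metadata))


# Placeholder values standing in for the real per-client datasets.
sampled_train_data = ["ds_0", "ds_1", "ds_2", "ds_3"]
meta_data = ["a", "b", "c", "d"]

per_client = pair_clients(sampled_train_data, meta_data)
print(per_client[0])    # ('ds_0', 'a')
print(len(per_client))  # 4
```

Returning a list rather than the bare `zip` object also means the pairs can be iterated more than once, since a `zip` iterator is exhausted after a single pass.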


Also, assuming you want meta_data to be a single string per client, you should change meta_data_type to tff.to_type(tf.string). tff.SequenceType is for representing general sequences, such as datasets.