Change of the dataset type in the execution stack
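The problem is that the dataset changes from one type to another at different points in the execution stack. For example, if I add a new dataset class with additional member attributes of interest (inheriting from one of the classes in ops.data.dataset_ops, such as UnaryDataset), then at a later point of execution (the client_update function) the dataset has been converted to the _VariantDataset type, so any added attributes are lost. So the question is: how can the member attributes of a newly defined dataset class be preserved during execution? Below is the emnist example, in which the type changes from ParallelMapDataset to _VariantDataset.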
In the function client_datasets at line 194 of training_utils.py, I made the following change to print the type of each dataset:
def client_datasets(round_num):
  sampled_clients = sample_clients_fn(round_num)
  sampled_client_datasets = []
  for client in sampled_clients:
    d = train_dataset.create_tf_dataset_for_client(client)
    sampled_client_datasets.append(d)  # reuse d instead of building the dataset twice
    tf.print('CLIENT DATASETS: ', d, type(d))
  return sampled_client_datasets
The output is:
CLIENT DATASETS: <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS: <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS: <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS: <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS: <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS: <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS: <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS: <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS: <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
CLIENT DATASETS: <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
Then, inside the tf.function client_update, which is called on the client at line 178 of fed_avg_schedule.py, the dataset has a different type:
@tf.function
def client_update(model,
                  dataset,
                  initial_weights,
                  client_optimizer,
                  client_weight_fn=None):
  """Updates client model.

  Args:
    model: A `tff.learning.Model`.
    dataset: A `tf.data.Dataset`.
    initial_weights: A `tff.learning.Model.weights` from server.
    client_optimizer: A `tf.keras.optimizers.Optimizer` object.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the
      weight in the federated average of model deltas. If not provided, the
      default is the total number of examples processed on device.

  Returns:
    A `ClientOutput`.
  """
  tf.print('CLIENT UPDATE: ', dataset, type(dataset))
  ...  # rest of the function omitted
The output is:
CLIENT UPDATE: <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE: <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE: <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE: <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE: <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE: <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE: <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE: <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE: <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
CLIENT UPDATE: <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>
I may be wrong, but after some tracing I found that at some point the function _to_components(self, value) of DatasetSpec is called to perform the conversion:
def _to_components(self, value):
  return value._variant_tensor  # pylint: disable=protected-access
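The conversion described above can be modelled in plain Python. The sketch below is a toy (no TensorFlow involved; every class and function name is hypothetical). It assumes, consistent with the _to_components snippet, that only the underlying handle is kept when a dataset crosses a traced-function boundary and that a generic wrapper is rebuilt from that handle (in TensorFlow this appears to be what DatasetSpec._from_components does with _VariantDataset). Under that assumption, any extra Python attributes on a subclass cannot survive.

```python
class ToyDataset:
    """Stand-in for tf.data.Dataset: wraps an opaque handle (like a variant tensor)."""

    def __init__(self, handle, element_spec):
        self._variant_tensor = handle
        self.element_spec = element_spec


class MyDatasetWithMetadata(ToyDataset):
    """Hypothetical user subclass carrying an extra Python attribute."""

    def __init__(self, handle, element_spec, client_id):
        super().__init__(handle, element_spec)
        self.client_id = client_id


class ToyVariantDataset(ToyDataset):
    """Stand-in for the generic _VariantDataset rebuilt from a handle."""


class ToyDatasetSpec:
    """Stand-in for DatasetSpec: only the handle and element spec are recorded."""

    def __init__(self, element_spec):
        self.element_spec = element_spec

    def to_components(self, value):
        # Mirrors DatasetSpec._to_components: keep only the variant handle.
        return value._variant_tensor

    def from_components(self, components):
        # Rebuild a generic dataset; the subclass and its attributes were never recorded.
        return ToyVariantDataset(components, self.element_spec)


def cross_function_boundary(dataset):
    """Simulate passing a dataset into a traced function."""
    spec = ToyDatasetSpec(dataset.element_spec)
    return spec.from_components(spec.to_components(dataset))


original = MyDatasetWithMetadata("handle-0", ("float32", "int32"), client_id="client_a")
inside = cross_function_boundary(original)
print(type(inside).__name__)         # ToyVariantDataset
print(hasattr(inside, "client_id"))  # False
```

The data handle survives the round trip; the Python-level type and the client_id attribute do not.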
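EDIT - following the suggested answer
Below are the changes I made to the simple_fedavg example after pulling the latest version of the federated repository.
First, I added/modified the following lines in build_fed_avg_process of simple_fedavg_tff.py: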
server_message_type = server_message_fn.type_signature.result
tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
meta_data_type = tff.SequenceType(tf.string)

@tff.tf_computation(tf_dataset_type, meta_data_type, server_message_type)
def client_update_fn(tf_dataset, meta_data, server_message):
  model = model_fn()
  client_optimizer = client_optimizer_fn()
  return client_update(model, tf_dataset, meta_data, server_message,
                       client_optimizer)

@tff.tf_computation((tf_dataset_type, meta_data_type))
def extract_data_metadata_fn(tf_dataset_metadata_tuple):
  x, y = tf_dataset_metadata_tuple
  return x, y

federated_server_state_type = tff.FederatedType(server_state_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(
    (tf_dataset_type, meta_data_type), tff.CLIENTS)

@tff.federated_computation(federated_server_state_type,
                           federated_dataset_type)
def run_one_round(server_state, federated_dataset):
  """Orchestration logic for one round of computation.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.data.Dataset` with placement
      `tff.CLIENTS`.

  Returns:
    A tuple of updated `ServerState` and `tf.Tensor` of average loss.
  """
  server_message = tff.federated_map(server_message_fn, server_state)
  server_message_at_client = tff.federated_broadcast(server_message)
  data_set, meta_data = tff.federated_map(extract_data_metadata_fn,
                                          federated_dataset)
  # client_outputs = tff.federated_map(
  #     client_update_fn, (federated_dataset, server_message_at_client))
  client_outputs = tff.federated_map(
      client_update_fn, (data_set, meta_data, server_message_at_client))
In simple_fedavg_tf.py, I added the following print line for meta_data:
@tf.function
def client_update(model, dataset, meta_data, server_message, client_optimizer):
  """Performs client local training of `model` on `dataset`.

  Args:
    model: A `tff.learning.Model`.
    dataset: A `tf.data.Dataset`.
    meta_data: Per-client metadata.
    server_message: A `BroadcastMessage` from server.
    client_optimizer: A `tf.keras.optimizers.Optimizer`.

  Returns:
    A `ClientOutput`.
  """
  tf.print(meta_data)
  model_weights = model.weights
  initial_weights = server_message.model_weights
  client_ids = server_message.client_ids
  tff.utils.assign(model_weights, initial_weights)
In the main file emnist_simple_fedavg.py, I modified the following lines of the main training loop in the main function:
meta_data = ['a','b','c','d']
server_state, train_metrics = iterative_process.next(server_state, (sampled_train_data, meta_data))
It did not work; I get the following error:
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.py", line 176, in <module>
app.run(main)
File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 299, in run
_run_main(main, args)
File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 250, in _run_main
sys.exit(main(argv))
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.py", line 166, in main
server_state, train_metrics = iterative_process.next(server_state, (sampled_train_data, sampled_clients.tolist()))
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/utils/function_utils.py", line 563, in __call__
return context.invoke(self, arg)
File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 49, in wrapped_f
return Retrying(*dargs, **dkw).call(f, *args, **kw)
File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 206, in call
return attempt.get(self._wrap_exception)
File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 247, in get
six.reraise(self.value[0], self.value[1], self.value[2])
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/six/__init__.py", line 693, in reraise
raise value
File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 200, in call
attempt = Attempt(fn(*args, **kwargs), attempt_number, False)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/execution_context.py", line 215, in invoke
_ingest(executor, unwrapped_arg, arg.type_signature)))
File "/usr/lib/python3.6/asyncio/base_events.py", line 484, in run_until_complete
return future.result()
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 388, in _wrapped
return await coro
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/execution_context.py", line 99, in _ingest
ingested = await asyncio.gather(*ingested)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/execution_context.py", line 104, in _ingest
return await executor.create_value(val, type_spec)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/reference_resolving_executor.py", line 289, in create_value
value, type_spec))
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/caching_executor.py", line 245, in create_value
await cached_value.target_future
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 111, in create_value
self._target_executor.create_value(value, type_spec))
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 105, in _delegate
result_value = await _delegate_with_trace_ctx(coro, self._event_loop)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 388, in _wrapped
return await coro
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/federating_executor.py", line 383, in create_value
return await self._strategy.compute_federated_value(value, type_spec)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/federated_resolving_strategy.py", line 275, in compute_federated_value
for v, c in zip(value, children)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/reference_resolving_executor.py", line 282, in create_value
*[self.create_value(val, t) for (_, val), t in zip(v_el, type_spec)])
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/reference_resolving_executor.py", line 289, in create_value
value, type_spec))
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/caching_executor.py", line 245, in create_value
await cached_value.target_future
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 111, in create_value
self._target_executor.create_value(value, type_spec))
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 105, in _delegate
result_value = await _delegate_with_trace_ctx(coro, self._event_loop)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 388, in _wrapped
return await coro
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/eager_tf_executor.py", line 464, in create_value
return EagerValue(value, self._tf_function_cache, type_spec, self._device)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/eager_tf_executor.py", line 367, in __init__
type_spec, device)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/eager_tf_executor.py", line 335, in to_representation_for_type
type_conversions.TF_DATASET_REPRESENTATION_TYPES)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/py_typecheck.py", line 41, in check_type
type_string(type_spec), type_string(type(target))))
TypeError: Expected tensorflow.python.data.ops.dataset_ops.DatasetV2 or tensorflow.python.data.ops.dataset_ops.DatasetV1, found str.
E0721 23:53:29.388700 139706363909952 base_events.py:1285] Task was destroyed but it is pending!
task: <Task pending coro=<trace.<locals>.async_trace() running at /root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py:200> wait_for=<Future pending cb=[_chain_future.<locals>._call_check_cancel() at /usr/lib/python3.6/asyncio/futures.py:403, <TaskWakeupMethWrapper object at 0x7f0f7c07eca8>()]> cb=[<TaskWakeupMethWrapper object at 0x7f0f7c07e648>()]>
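The new dataset Python class would need to support serialization. This is necessary because TensorFlow Federated is designed to run on machines that are not necessarily the machine on which the computation was written (for example, smartphones in the case of cross-device federated learning). Those machines may not be running Python, so they would not understand the newly created subclass, and the serialization layer would therefore need to be updated. However, this is fairly low-level, and there may be other ways to achieve the intended goal.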
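A hedged, pure-Python illustration of the serialization point (a toy, not TFF's actual serialization layer; all names are hypothetical, and JSON merely stands in for a language-neutral wire format): once a value has to travel through such a format, only the fields the format explicitly encodes can be recovered on the other side.

```python
import json


class DatasetWithMetadata:
    """Hypothetical dataset-like object carrying an extra attribute."""

    def __init__(self, elements, metadata):
        self.elements = list(elements)
        self.metadata = metadata


def serialize_elements_only(ds):
    # A generic serializer that only knows about the data itself.
    return json.dumps({"elements": ds.elements})


def serialize_with_metadata(ds):
    # An extended serializer that explicitly encodes the extra field.
    return json.dumps({"elements": ds.elements, "metadata": ds.metadata})


def deserialize(payload):
    # The receiver sees only the wire format, never the Python class.
    record = json.loads(payload)
    return record.get("elements"), record.get("metadata")


ds = DatasetWithMetadata([1, 2, 3], metadata="client_a")
print(deserialize(serialize_elements_only(ds)))  # ([1, 2, 3], None)
print(deserialize(serialize_with_metadata(ds)))  # ([1, 2, 3], 'client_a')
```

In this toy, preserving the attribute requires changing the serializer itself, which is the low-level work the answer refers to.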
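Stepping back: if the goal is to provide metadata together with the dataset to each client, it may be easier to change the function signature of the iterative process returned by fed_avg_schedule.build_fed_avg_process so that it accepts a tuple of (dataset, metadata structure) for each client.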
Currently the signature of the next computation is (in the TFF type shorthand introduced in Custom Federated Algorithms, Part 1: Introduction to the Federated Core):
(<ServerState@SERVER, Datasets@CLIENTS> -> <ServerState@SERVER, Metrics@SERVER>)
(See the definition of ServerState. Datasets and Metrics are determined by the model and the dataset.)
Instead, we might want a signature like the following:
(<ServerState@SERVER, <Datasets, Metadata>@CLIENTS> -> <ServerState@SERVER, Metrics@SERVER>)
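In plain Python terms, this signature means each client receives its own (dataset, metadata) pair. Below is a toy sketch of one round that consumes that structure (no TFF involved; the function names, the scalar "model", and the example-weighted averaging rule are illustrative assumptions loosely echoing federated averaging).

```python
from statistics import mean


def toy_client_update(dataset, metadata, server_weight):
    """Hypothetical client step: sees only its own data and its own metadata."""
    local_mean = mean(dataset)
    return {
        "delta": local_mean - server_weight,
        "num_examples": len(dataset),
        "tag": metadata,
    }


def toy_run_one_round(server_weight, clients):
    """Toy analogue of
    (<ServerState@SERVER, <Datasets, Metadata>@CLIENTS> -> <ServerState@SERVER, Metrics@SERVER>).

    `clients` holds one (dataset, metadata) pair per client.
    """
    # Like a federated_map: run the client step on every client's pair.
    outputs = [toy_client_update(ds, md, server_weight) for ds, md in clients]
    # Example-weighted average of the client deltas, applied on the "server".
    total = sum(o["num_examples"] for o in outputs)
    avg_delta = sum(o["delta"] * o["num_examples"] for o in outputs) / total
    metrics = {"clients_seen": [o["tag"] for o in outputs]}
    return server_weight + avg_delta, metrics


clients = [([1.0, 3.0], "a"), ([4.0], "b")]
new_weight, metrics = toy_run_one_round(0.0, clients)
print(round(new_weight, 4), metrics)  # 2.6667 {'clients_seen': ['a', 'b']}
```

The design choice mirrored here is that metadata travels as ordinary per-client data next to the dataset rather than as an attribute of a dataset subclass, so nothing depends on the dataset's Python type surviving execution.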
To do this, we could do the following:
Based on the updated information and the error log, I think the problem is in this part:
iterative_process.next(server_state, (sampled_train_data, meta_data))
I guess what you need is for the second argument of next to be an iterable of (sampled_train_data_element, meta_data_element) tuples, one element per sampled client.
So this can be achieved by changing it to:
iterative_process.next(server_state, zip(sampled_train_data, meta_data))
Or, if that does not work, maybe this?
iterative_process.next(server_state, list(zip(sampled_train_data, meta_data)))
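The diagnosis can be reproduced with a small pure-Python model (a toy, not TFF's actual ingestion code; FakeDataset and ingest_clients are hypothetical). It assumes a CLIENTS-placed argument is read as one entry per client, and that each entry is zipped against the two fields of the per-client type, similar to the zip(v_el, type_spec) call visible in the traceback. Under that assumption a tuple of two lists looks like two clients, and the second "client" hands 'a' to the dataset field, which would be consistent with the "found str" error.

```python
class FakeDataset:
    """Hypothetical stand-in for a tf.data.Dataset."""

    def __init__(self, name):
        self.name = name


def ingest_clients(value):
    """Toy ingestion of a value of type <dataset, metadata>@CLIENTS.

    The outer iterable is treated as one entry per client. Each entry is
    zipped with the two field names, so surplus items are silently dropped,
    and the first field must be a dataset.
    """
    per_client = []
    for entry in value:
        fields = dict(zip(("dataset", "metadata"), entry))
        dataset = fields.get("dataset")
        if not isinstance(dataset, FakeDataset):
            raise TypeError(f"Expected a dataset, found {type(dataset).__name__}.")
        per_client.append((dataset, fields.get("metadata")))
    return per_client


sampled_train_data = [FakeDataset(f"client_{i}") for i in range(4)]
meta_data = ["a", "b", "c", "d"]

# A tuple of two lists is read as TWO clients; the second one is ['a', 'b', 'c', 'd'].
try:
    ingest_clients((sampled_train_data, meta_data))
    error_message = None
except TypeError as err:
    error_message = str(err)
print(error_message)  # Expected a dataset, found str.

# One (dataset, metadata) pair per sampled client.
pairs = ingest_clients(list(zip(sampled_train_data, meta_data)))
print([(d.name, m) for d, m in pairs])
```

With the zipped form, every client receives exactly one dataset and one metadata value.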
Also, assuming you want meta_data to be a single string per client, you should change meta_data_type to tff.to_type(tf.string). tff.SequenceType is used to represent general sequences, such as datasets.
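A toy check shows why the declared type matters (plain Python; the strings "sequence<string>" and "string" merely stand in for tff.SequenceType(tf.string) and tff.to_type(tf.string), and FakeDataset is hypothetical). Each client supplies one string such as 'a', which satisfies a scalar string type but not a sequence type that expects a dataset-like value.

```python
class FakeDataset:
    """Hypothetical stand-in for a tf.data.Dataset, i.e. what a sequence value must be."""

    def __init__(self, elements):
        self.elements = list(elements)


def check_value(value, declared_type):
    """Toy type check: 'sequence<string>' needs a dataset of strings, 'string' needs a str."""
    if declared_type == "sequence<string>":
        return isinstance(value, FakeDataset) and all(isinstance(e, str) for e in value.elements)
    if declared_type == "string":
        return isinstance(value, str)
    raise ValueError(f"unknown type {declared_type!r}")


print(check_value("a", "sequence<string>"))  # False: a bare str is not a sequence
print(check_value("a", "string"))            # True
```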
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 111, in create_value
self._target_executor.create_value(value, type_spec))
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 105, in _delegate
result_value = await _delegate_with_trace_ctx(coro, self._event_loop)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 388, in _wrapped
return await coro
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/federating_executor.py", line 383, in create_value
return await self._strategy.compute_federated_value(value, type_spec)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/federated_resolving_strategy.py", line 275, in compute_federated_value
for v, c in zip(value, children)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/reference_resolving_executor.py", line 282, in create_value
*[self.create_value(val, t) for (_, val), t in zip(v_el, type_spec)])
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/reference_resolving_executor.py", line 289, in create_value
value, type_spec))
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/caching_executor.py", line 245, in create_value
await cached_value.target_future
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 111, in create_value
self._target_executor.create_value(value, type_spec))
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 105, in _delegate
result_value = await _delegate_with_trace_ctx(coro, self._event_loop)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 388, in _wrapped
return await coro
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/eager_tf_executor.py", line 464, in create_value
return EagerValue(value, self._tf_function_cache, type_spec, self._device)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/eager_tf_executor.py", line 367, in __init__
type_spec, device)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/eager_tf_executor.py", line 335, in to_representation_for_type
type_conversions.TF_DATASET_REPRESENTATION_TYPES)
File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/py_typecheck.py", line 41, in check_type
type_string(type_spec), type_string(type(target))))
TypeError: Expected tensorflow.python.data.ops.dataset_ops.DatasetV2 or tensorflow.python.data.ops.dataset_ops.DatasetV1, found str.
E0721 23:53:29.388700 139706363909952 base_events.py:1285] Task was destroyed but it is pending!
task: <Task pending coro=<trace.<locals>.async_trace() running at /root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py:200> wait_for=<Future pending cb=[_chain_future.<locals>._call_check_cancel() at /usr/lib/python3.6/asyncio/futures.py:403, <TaskWakeupMethWrapper object at 0x7f0f7c07eca8>()]> cb=[<TaskWakeupMethWrapper object at 0x7f0f7c07e648>()]>
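The new dataset Python class would need to support serialization. This is necessary because TensorFlow Federated is designed to run on machines that are not necessarily the same machine on which the computation was written (for example, smartphones in the case of cross-device federated learning). Those machines may not be running Python, so they would not understand the newly created subclass, and the serialization layer would therefore need to be updated. However, this is fairly low-level, and there may be other ways to achieve the intended goal.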
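The loss of extra attributes can be mimicked in plain Python. This is only an analogy: JSON is a stand-in here, and TFF's actual serialization works on TensorFlow representations, not JSON. The point it shows is that once a value is reduced to a format another runtime can read, only the data survives, not the Python subclass or its extra attributes. This matches the question's observation that the dataset reappears as a plain _VariantDataset. TaggedList and client_id are hypothetical names.

```python
import json


class TaggedList(list):
    """Hypothetical list subclass carrying an extra, Python-only attribute."""

    def __init__(self, items, client_id):
        super().__init__(items)
        self.client_id = client_id


original = TaggedList([1, 2, 3], client_id="client_a")

# Serialize to a language-neutral format and rebuild on the "other side".
# JSON only knows about arrays, so the subclass and its attribute are dropped.
restored = json.loads(json.dumps(original))

print(type(original).__name__, original.client_id)  # TaggedList client_a
print(type(restored).__name__, hasattr(restored, "client_id"))  # list False
```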
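A way out: if the goal is to provide metadata together with the dataset to each client, it may be easier to change the function signature of the iterative process returned by fed_avg_schedule.build_fed_avg_process so that it accepts a tuple of (dataset, metadata structure) for each client.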
Currently, the signature of the next computation is (in the TFF type shorthand introduced in Custom Federated Algorithms, Part 1: Introduction to the Federated Core):
(<ServerState@SERVER, Datasets@CLIENTS> -> <ServerState@SERVER, Metrics@SERVER>)
(where ServerState, Datasets, and Metrics are defined by the model and the dataset)
Instead, we might want a signature that looks like this:
(<ServerState@SERVER, <Datasets, Metadata>@CLIENTS> -> <ServerState@SERVER, Metrics@SERVER>)
To do this, we could do the following:
Based on the updated information and the error log, I think the problem is in this part:
iterative_process.next(server_state, (sampled_train_data, meta_data))
My guess is that you need the second argument of next to be an iterable of (sampled_train_data_element, meta_data_element) tuples, one element per sampled client.
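One plausible reading of the traceback above: compute_federated_value pairs the argument with the client executors via zip(value, children), so each top-level element of the argument is treated as one client's value. With a 2-tuple of whole lists, the second "client" receives the metadata list, and its first element (a str) lands where a dataset is expected, which would match "found str". The sketch below models only that shape mismatch; plain lists stand in for tf.data.Dataset objects, and dataset_slot_per_client is a hypothetical helper, not a TFF API.

```python
# Plain lists stand in for tf.data.Dataset objects and strings stand in for
# per-client metadata; no TensorFlow or TFF is needed for this illustration.
sampled_train_data = [["x0", "x1"], ["y0"], ["z0", "z1", "z2"]]
meta_data = ["client_a", "client_b", "client_c"]


def dataset_slot_per_client(federated_arg):
    """Treat each top-level element as one client's <dataset, metadata> value
    (mirroring zip(value, children)) and return what lands in the dataset slot."""
    return [client_value[0] for client_value in federated_arg]


# Passed in the question: a 2-tuple of whole lists ("structure of lists").
wrong_arg = (sampled_train_data, meta_data)
wrong_slots = dataset_slot_per_client(wrong_arg)
# Only 2 "clients", and the second one gets a str where a dataset belongs.

# Expected by the signature: one (dataset, metadata) pair per sampled client.
right_arg = [(sampled_train_data[i], meta_data[i]) for i in range(len(meta_data))]
right_slots = dataset_slot_per_client(right_arg)

print(len(wrong_slots), type(wrong_slots[1]).__name__)  # 2 str
print(len(right_slots), [type(s).__name__ for s in right_slots])  # 3 ['list', 'list', 'list']
```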
So this could be achieved by changing it to
iterative_process.next(server_state, zip(sampled_train_data, meta_data))
Or, if that doesn't work, perhaps this?
iterative_process.next(server_state, list(zip(sampled_train_data, meta_data)))
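A likely reason the list(...) variant might be needed (an assumption, not confirmed in the answer): in Python 3, zip returns a lazy, single-use iterator with no len(), while ingestion code may check for a list or tuple, call len() on the argument, or traverse it more than once. The sketch shows the difference with plain lists standing in for datasets.

```python
sampled_train_data = [["x0"], ["y0", "y1"]]
meta_data = ["client_a", "client_b"]

# zip(...) in Python 3 is a lazy iterator: no len(), and it can be consumed once.
pairs_iter = zip(sampled_train_data, meta_data)
has_len = hasattr(pairs_iter, "__len__")
first_pass = list(pairs_iter)
second_pass = list(pairs_iter)  # already exhausted, so this is empty

# list(zip(...)) materializes the pairs: sized, indexable and reusable.
pairs_list = list(zip(sampled_train_data, meta_data))

print(has_len, len(first_pass), len(second_pass))  # False 2 0
print(len(pairs_list), pairs_list[1])  # 2 (['y0', 'y1'], 'client_b')
```

Materializing the list costs nothing meaningful here, since it only holds references to the per-client datasets and metadata.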
Also, assuming you want meta_data to be a single string per client, you should change meta_data_type to tff.to_type(tf.string). tff.SequenceType is for representing general sequences, such as datasets.