TensorFlow Federated (TFF) TypeError in tff.templates.IterativeProcess.next() when clients_per_round exceeds 99
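I implemented a custom federated learning GAN training loop with TFF, similar to this code by Google Research.
The client data for a particular training round is selected with the following snippet: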
def client_dataset_fn():
    # Sample clients and data
    sampled_clients = np.random.choice(train_data.client_ids, size=cfg.clients_per_round, replace=False)
    datasets = [(next(client_gen_inputs_iterator),
                 train_data.create_tf_dataset_for_client(client_id).take(cfg.n_critic))
                for client_id in sampled_clients]
    return datasets

client_noise_inputs, client_real_data = zip(*client_dataset_fn())
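This works as long as cfg.clients_per_round is at most 99. When it is set to 100 or more (the total number of clients is, of course, larger than that), I get the following error: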
Traceback (most recent call last):
File "main.py", line 109, in main
metrics = run_single_trial(train_data, test_data, cfg)
File "/mnt/workspace/tff/GAN/federated/fedgan_main.py", line 73, in run_single_trial
metrics = train_loop(iterative_process, server_dataset_fn, client_dataset_fn, model, eval_hook_fn, cfg)
File "/mnt/workspace/tff/GAN/federated/fedgan_main.py", line 124, in train_loop
client_real_data)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/core/impl/computation/function_utils.py", line 525, in __call__
return context.invoke(self, arg)
File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 49, in wrapped_f
return Retrying(*dargs, **dkw).call(f, *args, **kw)
File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 206, in call
return attempt.get(self._wrap_exception)
File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 247, in get
six.reraise(self.value[0], self.value[1], self.value[2])
File "/usr/local/lib/python3.6/dist-packages/six.py", line 703, in reraise
raise value
File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 200, in call
attempt = Attempt(fn(*args, **kwargs), attempt_number, False)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/core/impl/executors/execution_context.py", line 226, in invoke
_ingest(executor, unwrapped_arg, arg.type_signature)))
File "/usr/lib/python3.6/asyncio/base_events.py", line 484, in run_until_complete
return future.result()
File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/common_libs/tracing.py", line 396, in _wrapped
return await coro
File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/core/impl/executors/execution_context.py", line 111, in _ingest
ingested = await asyncio.gather(*ingested)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/core/impl/executors/execution_context.py", line 116, in _ingest
return await executor.create_value(val, type_spec)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/common_libs/tracing.py", line 201, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/core/impl/executors/reference_resolving_executor.py", line 294, in create_value
value, type_spec))
File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/common_libs/tracing.py", line 201, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 111, in create_value
self._target_executor.create_value(value, type_spec))
File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 105, in _delegate
result_value = await _delegate_with_trace_ctx(coro, self._event_loop)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/common_libs/tracing.py", line 396, in _wrapped
return await coro
File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/common_libs/tracing.py", line 201, in async_trace
result = await fn(*fn_args, **fn_kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/core/impl/executors/federating_executor.py", line 394, in create_value
return await self._strategy.compute_federated_value(value, type_spec)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/core/impl/executors/federated_composing_strategy.py", line 279, in compute_federated_value
py_typecheck.check_type(value, list)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/common_libs/py_typecheck.py", line 41, in check_type
type_string(type_spec), type_string(type(target))))
TypeError: Expected list, found tuple.
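While debugging, I inspected the target variable in the last frame of the traceback and found that it is the client_real_data and client_noise_inputs mentioned above. Their type is indeed tuple rather than list, but that does not change with the value of cfg.clients_per_round. The only place cfg.clients_per_round is used is the random choice shown above.
I really cannot explain why this happens. Maybe someone who has run into something similar can help.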
I am using the following package versions:
- Python 3.6.9 or 3.8.10 (tried both)
- tensorflow 2.5.1
- tensorflow federated 0.19.0
- retrying 1.3.3
- six 1.15.0
As a workaround, I now convert client_noise_inputs and client_real_data manually with list(tuple_var), but I am still curious why a list is required.
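The workaround can be sketched in isolation. This is only an illustration: `pairs` is a hypothetical stand-in for the output of client_dataset_fn(), and no TFF is involved. It shows that zip(*...) always yields tuples, and one way to get lists instead.

```python
# Hypothetical stand-in for the (noise_input, real_data) pairs
# that client_dataset_fn() returns; the strings are placeholders.
pairs = [("noise_%d" % i, "data_%d" % i) for i in range(3)]

# zip(*pairs) transposes the pairs, but every result is a tuple.
noise_tuple, data_tuple = zip(*pairs)

# Converting each transposed sequence to a list, as in the workaround.
client_noise_inputs, client_real_data = map(list, zip(*pairs))
```

The element order is unchanged; only the container type differs.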
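(Copied and pasted from the original on GitHub)
To me this looks like an implementation difference between the federated_composing_strategy and the federated_resolving_strategy. IIRC, by default we don't inject a composing executor into your stack until you reach 100 clients, which would be the source of this exciting mystery.
In particular, the composing strategy is coded against the assumption that incoming client-placed values are represented as a list, whereas the resolving strategy codes against a much more flexible set of containers.
Forcing your client-placed values into a list isn't crazy. We could also extend the permitted representations of client-placed values in the composing executor to match those in the resolving executor, possibly pulling the appropriate logic into a shared place like here. I think we'd be very happy to accept a contribution for it if you're willing!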
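Until something like that lands, a small normalizing helper on the caller's side might look like the sketch below. Note that `as_client_list` is a hypothetical name and not part of TFF. It only assumes that the values for the clients arrive as a list or a tuple.

```python
def as_client_list(value):
    """Hypothetical helper (not part of TFF): return per-client values as a list."""
    if isinstance(value, list):
        # Already in the representation the composing strategy checks for.
        return value
    if isinstance(value, tuple):
        # For example, the result of zip(*...); copy it into a list.
        return list(value)
    raise TypeError(
        "Expected list or tuple, found {}.".format(type(value).__name__))
```

Calling it on both client arguments before passing them to next() keeps the call site the same whether or not the composing executor is in the stack.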