dask, execute non-serializable object on every worker

I am trying to execute the graph produced by the following code:

from dask import delayed

energies = [10, 20]
system = delayed(make_non_serializable_oject)(x=1)
trans = [delayed(some_function_that_uses_system)(system, energy) for energy in energies]
result = delayed(list)(trans)
result.visualize()

When I call result.compute() the computation never finishes.

Calling result.compute(get=dask.async.get_sync) or result.compute(get=dask.threaded.get) both work. However, result.compute(get=dask.multiprocessing.get) does not, and raises the following error:

---------------------------------------------------------------------------
RemoteError                               Traceback (most recent call last)
<ipython-input-70-b5c8f2a1c6f6> in <module>()
----> 1 result.compute(get=dask.multiprocessing.get)

/home/bnijholt/anaconda3/lib/python3.5/site-packages/dask/base.py in compute(self, **kwargs)
     76             Extra keywords to forward to the scheduler ``get`` function.
     77         """
---> 78         return compute(self, **kwargs)[0]
     79 
     80     @classmethod

/home/bnijholt/anaconda3/lib/python3.5/site-packages/dask/base.py in compute(*args, **kwargs)
    169         dsk = merge(var.dask for var in variables)
    170     keys = [var._keys() for var in variables]
--> 171     results = get(dsk, keys, **kwargs)
    172 
    173     results_iter = iter(results)

/home/bnijholt/anaconda3/lib/python3.5/site-packages/dask/multiprocessing.py in get(dsk, keys, num_workers, func_loads, func_dumps, optimize_graph, **kwargs)
     81         # Run
     82         result = get_async(apply_async, len(pool._pool), dsk3, keys,
---> 83                            queue=queue, get_id=_process_get_id, **kwargs)
     84     finally:
     85         if cleanup:

/home/bnijholt/anaconda3/lib/python3.5/site-packages/dask/async.py in get_async(apply_async, num_workers, dsk, result, cache, queue, get_id, raise_on_exception, rerun_exceptions_locally, callbacks, **kwargs)
    479                     _execute_task(task, data)  # Re-execute locally
    480                 else:
--> 481                     raise(remote_exception(res, tb))
    482             state['cache'][key] = res
    483             finish_task(dsk, key, state, results, keyorder.get)

RemoteError: 
---------------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/bnijholt/anaconda3/lib/python3.5/multiprocessing/managers.py", line 228, in serve_client
    request = recv()
  File "/home/bnijholt/anaconda3/lib/python3.5/multiprocessing/connection.py", line 251, in recv
    return ForkingPickler.loads(buf.getbuffer())
  File "kwant/graph/core.pyx", line 664, in kwant.graph.core.CGraph_malloc.__cinit__ (kwant/graph/core.c:8330)
TypeError: __cinit__() takes exactly 6 positional arguments (0 given)
---------------------------------------------------------------------------

Traceback
---------
  File "/home/bnijholt/anaconda3/lib/python3.5/site-packages/dask/async.py", line 273, in execute_task
    queue.put(result)
  File "<string>", line 2, in put
  File "/home/bnijholt/anaconda3/lib/python3.5/multiprocessing/managers.py", line 732, in _callmethod
    raise convert_to_error(kind, result)

With ipyparallel I would execute make_non_serializable_oject on every engine, which solves the problem for that case.

I would like to use dask for my parallel computations. How can I solve this problem?

Ensure that your data can be serialized

This part of your traceback shows that the objects from the kwant library don't serialize themselves well:

Traceback (most recent call last):
  File "/home/bnijholt/anaconda3/lib/python3.5/multiprocessing/managers.py", line 228, in serve_client
    request = recv()
  File "/home/bnijholt/anaconda3/lib/python3.5/multiprocessing/connection.py", line 251, in recv
    return ForkingPickler.loads(buf.getbuffer())
  File "kwant/graph/core.pyx", line 664, in kwant.graph.core.CGraph_malloc.__cinit__ (kwant/graph/core.c:8330)
TypeError: __cinit__() takes exactly 6 positional arguments (0 given)

This is why both the multiprocessing and distributed schedulers fail: Dask needs to be able to serialize data in order to move it between processes.
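
You can reproduce this failure outside of dask with a plain pickle round-trip (a minimal check, using the make_non_serializable_oject constructor from the question):

import pickle

# Build the object directly, without dask.
system = make_non_serializable_oject(x=1)

# If this raises (e.g. the TypeError from kwant above), the
# multiprocessing and distributed schedulers will fail the same way.
pickle.loads(pickle.dumps(system))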

The simplest and cleanest way to resolve this is to improve the serialization of your data. Ideally you can accomplish this by improving kwant itself. You could also manage it through dask's custom serialization, but that is probably more work at the moment.
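
If, for example, the object can be rebuilt from its constructor arguments, one minimal sketch is to implement the pickle protocol on the offending class. The class below is a hypothetical stand-in, not kwant's real API:

class NonSerializableObject:
    """Hypothetical stand-in for the object built by make_non_serializable_oject."""
    def __init__(self, x):
        self.x = x
        # ... expensive, unpicklable internals derived from x ...

    def __reduce__(self):
        # Tell pickle to rebuild the object by calling the constructor
        # again, instead of trying to copy the unpicklable internals.
        return (self.__class__, (self.x,))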

Keep your data in one place

OK, so let's assume that you can't improve serialization and need to leave the data where it is. This will restrict you to embarrassingly parallel workflows (map). There are two solutions:

  1. Use the fuse optimization
  2. Track explicitly where each task runs

Fuse

You are going to create some non-serializable data, then run computations on it, then turn the result into serializable data before anything tries to move it back. This is fine as long as the scheduler never decides to move the data on its own. You can enforce this by fusing all of those tasks into a single atomic task. See the optimization docs for details:

import dask
from dask.optimize import fuse

bad_data = [f(...) for ...]                                       # non-serializable results
good_data = [convert_to_serializable_data(bd) for bd in bad_data]  # made picklable in place
dask.compute(good_data, optimizations=[fuse])
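
Applied to the graph from the question, that pattern might look like the sketch below. Here convert_to_serializable_data is a placeholder for whatever turns the kwant result into plain picklable values, and each energy gets its own independent chain so that fuse can collapse it into one task:

import dask
from dask import delayed
from dask.optimize import fuse

energies = [10, 20]

# One independent chain per energy: the non-serializable system is
# created, used, and converted on the same worker. pure=False gives
# each call its own key, so the chains stay separate.
trans = []
for energy in energies:
    system = delayed(make_non_serializable_oject, pure=False)(x=1)
    t = delayed(some_function_that_uses_system)(system, energy)
    trans.append(delayed(convert_to_serializable_data)(t))

# fuse merges each linear chain into a single atomic task.
dask.compute(*trans, optimizations=[fuse])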

Specify exactly where every computation should run yourself

See the Data Locality docs.
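
With the distributed scheduler you can also pin tasks to the worker that holds the non-serializable object, for example with the workers= keyword of Client.submit (a sketch, assuming dask.distributed; the addresses are placeholders for your actual scheduler and worker):

from distributed import Client

client = Client('127.0.0.1:8786')  # placeholder scheduler address
worker = '127.0.0.1:9000'          # placeholder worker address

# Build the object on one specific worker; the future stays there.
system = client.submit(make_non_serializable_oject, x=1, workers=worker)

# Run the dependent tasks on the same worker, so the system future
# never has to be serialized and moved.
trans = [client.submit(some_function_that_uses_system, system, energy,
                       workers=worker)
         for energy in [10, 20]]

results = client.gather(trans)  # only the (serializable) results move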