How can I run computations on this "nested" structured array using Dask?

I'm trying to get started with dask. In the toy example below I have three columns: site, counts, and readings. site and counts are scalar columns, while readings contains 3-dimensional arrays.

I can run computations on counts, but if I try the same on readings I get an exception. Am I using dask correctly here?

import dask.array as da
import numpy as np
import tables

dtype = np.dtype([
    ('site', 'S1'),
    ('counts', np.int8),
    ('readings', np.float64, (2, 2, 3))
])

with tables.open_file('test.hdf5', 'w') as f:
    sensors = f.create_table('/', 'sensors', description=dtype)
    rows = [(site, count, np.random.rand(2, 2, 3))
            for count, site in enumerate('abcdefghij')]
    sensors.append(rows)
    sensors.flush()

    # Operating on 'counts' column works fine...
    x = da.from_array(f.root.sensors.cols.counts, chunks=5)
    x_add = (x + 1).compute()

    # But on 'readings' does not
    y = da.from_array(f.root.sensors.cols.readings, chunks=5)
    y_add = (y + 1).compute()

When I call (y + 1).compute() I get the exception below. (The actual error at the bottom appears to come from building the error string for a TypeError inside pytables, so it isn't very helpful.)

TypeError                                 Traceback (most recent call last)
<ipython-input-115-77c7e132695c> in <module>()
     22     # But on readings column does not
     23     y = da.from_array(f.root.sensors.cols.readings, chunks=5)
---> 24     y_add = (y + 1).compute()

~/miniconda/lib/python3.6/site-packages/dask/base.py in compute(self, **kwargs)
     95             Extra keywords to forward to the scheduler ``get`` function.
     96         """
---> 97         (result,) = compute(self, traverse=False, **kwargs)
     98         return result
     99 

~/miniconda/lib/python3.6/site-packages/dask/base.py in compute(*args, **kwargs)
    202     dsk = collections_to_dsk(variables, optimize_graph, **kwargs)
    203     keys = [var._keys() for var in variables]
--> 204     results = get(dsk, keys, **kwargs)
    205 
    206     results_iter = iter(results)

~/miniconda/lib/python3.6/site-packages/dask/threaded.py in get(dsk, result, cache, num_workers, **kwargs)
     73     results = get_async(pool.apply_async, len(pool._pool), dsk, result,
     74                         cache=cache, get_id=_thread_get_id,
---> 75                         pack_exception=pack_exception, **kwargs)
     76 
     77     # Cleanup pools associated to dead threads

~/miniconda/lib/python3.6/site-packages/dask/local.py in get_async(apply_async, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, **kwargs)
    519                         _execute_task(task, data)  # Re-execute locally
    520                     else:
--> 521                         raise_exception(exc, tb)
    522                 res, worker_id = loads(res_info)
    523                 state['cache'][key] = res

~/miniconda/lib/python3.6/site-packages/dask/compatibility.py in reraise(exc, tb)
     58         if exc.__traceback__ is not tb:
     59             raise exc.with_traceback(tb)
---> 60         raise exc
     61 
     62 else:

~/miniconda/lib/python3.6/site-packages/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    288     try:
    289         task, data = loads(task_info)
--> 290         result = _execute_task(task, data)
    291         id = get_id()
    292         result = dumps((result, id))

~/miniconda/lib/python3.6/site-packages/dask/local.py in _execute_task(arg, cache, dsk)
    269         func, args = arg[0], arg[1:]
    270         args2 = [_execute_task(a, cache) for a in args]
--> 271         return func(*args2)
    272     elif not ishashable(arg):
    273         return arg

~/miniconda/lib/python3.6/site-packages/dask/array/core.py in getarray(a, b, lock)
     61         lock.acquire()
     62     try:
---> 63         c = a[b]
     64         if type(c) != np.ndarray:
     65             c = np.asarray(c)

~/miniconda/lib/python3.6/site-packages/tables/table.py in __getitem__(self, key)
   3455         else:
   3456             raise TypeError(
-> 3457                 "'%s' key type is not valid in this context" % key)
   3458 
   3459     def __iter__(self):

TypeError: not all arguments converted during string formatting

Finally, x_add comes out as array([ 6, 7, 8, 9, 10, 6, 7, 8, 9, 10], dtype=int8), which is +1 applied to the last chunk, tiled twice. I expected [1, 2, ..., 10]. Again, this makes me wonder whether I'm using dask as intended.
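For reference, one workaround that makes the toy example run is to materialize each column as an in-memory numpy array before handing it to dask. This is a sketch of my own rather than anything from the dask docs, and it reads the whole column into memory, so it defeats the purpose for data that doesn't fit in RAM:

# Sketch of a workaround: read each PyTables column into a plain numpy
# array first, so dask only ever slices ordinary ndarrays.
x = da.from_array(f.root.sensors.cols.counts[:], chunks=5)
x_add = (x + 1).compute()   # array([ 1,  2, ..., 10], dtype=int8)

y = da.from_array(f.root.sensors.cols.readings[:], chunks=5)
y_add = (y + 1).compute()   # no exception: numpy supports the slicing dask uses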

Dask.array expects the arrays it is given to support numpy-style slicing. PyTables does not appear to support this:

In [12]: f.root.sensors.cols.counts
Out[12]: /sensors.cols.counts (Column(10,), int8, idx=None)

In [13]: f.root.sensors.cols.counts[:]
Out[13]: array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int8)

In [14]: f.root.sensors.cols.readings
Out[14]: /sensors.cols.readings (Column(10, 2, 2, 3), float64, idx=None)

In [15]: f.root.sensors.cols.readings[:, :, :, :]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-15-5e59d077075e> in <module>()
----> 1 f.root.sensors.cols.readings[:, :, :, :]

/home/mrocklin/Software/anaconda/lib/python3.6/site-packages/tables/table.py in __getitem__(self, key)
   3453         else:
   3454             raise TypeError(
-> 3455                 "'%s' key type is not valid in this context" % key)
   3456 
   3457     def __iter__(self):

TypeError: not all arguments converted during string formatting
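To make the contrast concrete: for each chunk, dask's getarray (visible in the traceback above) indexes the backing array with a tuple of slices. A plain numpy array accepts such a key, while the PyTables Column raises the TypeError. A minimal illustration, with slice values chosen to match chunks=5 on the (10, 2, 2, 3) column:

import numpy as np

# numpy accepts the tuple-of-slices key that dask builds per chunk...
arr = np.random.rand(10, 2, 2, 3)
chunk = arr[slice(0, 5), slice(0, 2), slice(0, 2), slice(0, 3)]  # shape (5, 2, 2, 3)

# ...but Column.__getitem__ in PyTables only handles integers and plain
# slices, so the same key raises the TypeError shown above.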

If it's an option for you, I recommend using h5py instead. I generally find h5py to be saner than PyTables.
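As a minimal sketch of what that could look like for the toy data above (file and dataset names are my own; h5py datasets support numpy-style slicing, so dask.array can wrap them directly):

import dask.array as da
import h5py
import numpy as np

# Store each column as a plain HDF5 dataset rather than a compound table.
with h5py.File('test_h5py.hdf5', 'w') as f:
    f.create_dataset('site', data=np.array(list('abcdefghij'), dtype='S1'))
    f.create_dataset('counts', data=np.arange(10, dtype=np.int8))
    f.create_dataset('readings', data=np.random.rand(10, 2, 2, 3))

with h5py.File('test_h5py.hdf5', 'r') as f:
    # h5py datasets accept tuple-of-slices keys, so dask is happy.
    y = da.from_array(f['readings'], chunks=(5, 2, 2, 3))
    y_add = (y + 1).compute()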