具有核外计算的 dask 中的循环引用

Circular references in dask with out-of-core computations

我正在编写一个实现各种非核心算法的库,并且运行遇到了这样一个问题,即可以通过源构建一个循环依赖的计算图并将其存储到同一个内存映射对象中。例如

import numpy as np
import dask.array as da

array_shape = (8000, 8000)
chunk_shape = (100, 100)

### create the example data
adisk = np.memmap('/tmp/a.npy', dtype=np.float32, shape=array_shape)
bdisk = np.memmap('/tmp/b.npy', dtype=np.float32, shape=array_shape)

a = da.ones(mainshape, chunks=chunkshape, dtype=np.float32) 
b = 2*da.ones(mainshape, chunks=chunkshape, dtype=np.float32) 

a.store(adisk)
b.store(bdisk)
adisk.flush()
bdisk.flush()

### Begin demonstration of issue

c = da.from_array(adisk, chunks=chunkshape)
d = da.from_array(bdisk, chunks=chunkshape)

e = c@d

e.store(adisk)

# Assert fails because source data is overwritten before re-read
assert np.all(adisk[:] == adisk[0,0])

我的测试表明,如果数据太大而无法缓存在内存中,则 dask 将完成操作,但行为未定义,因为源数据可能会在结果数据被重用于计算的其他部分之前被结果数据覆盖。因此,上面的代码在一定的(机器相关的)矩阵大小以上不再产生正确的结果。

dask 是否提供任何方法来帮助检测和减轻这种循环依赖?

我研究了潜在的解决方案,我目前认为插件功能在这里可能有用,以检查存储操作的目标是否与链中更远的任何来源不同。

我相信 dask.array 通过文件名和最后修改时间标记内存映射对象。像这样的尚未编写的案例看起来像是失败案例。

我预计这会在任务优化阶段的某处失败。

短期我建议明确向 da.from_array 函数提供一个 name= 关键字以提供您自己的唯一 ID。

我为自己的目的创建了一个解决方案,可用于检查是否存在循环依赖关系。在我的例子中,由于代码结构,我可以很容易地保护潜在的问题,例如 assert no_circular_dependencies(e, adisk) 但一般来说,如果没有 Monkeypatching Dask,这将很难实现。

因为每个 dask 数组都保留了任务图中所有任务的列表,所以很容易检查图形以找到原始数据源并确保它们与目标数组不同。

def get_true_base(ndarr):
    if not isinstance(ndarr, np.ndarray):
        raise TypeError("expected numpy array")
    base = ndarr
    while True:
        try:
            base = base.base
        except AttributeError:
            break
    return base

def get_memmap_bases(dsk):
    # better to build a set of names and then grab arrays later
    nameset = set()
    for name, task in dsk.items():
        if isinstance(task[0], np.memmap):
            nameset.add(name)

    bases = [get_true_base(dsk[name][0]) for name in nameset]
    return bases 

def no_circular_dependency(dask_arr, target):
    if get_true_base(target) in get_memmap_bases(dask_arr.dask):
        return False
    return True