在 dask.Array 任务图中嵌入 pre/post-compute 操作

embedding pre/post-compute operations in a dask.Array task graph

我有兴趣创建一个 dask.array.Array 来打开和关闭资源 before/after compute()。但是,我不想对最终用户将如何调用 compute 做任何假设,我想避免创建自定义 dask Array subclass 或代理对象,所以我试图将操作嵌入到数组底层的 __dask_graph__ 中。

(旁白:请暂时忽略关于在 dask 中使用有状态对象的警告,我知道风险,这个问题只是关于任务图操作)。

考虑以下 class,它模拟了一个文件 reader,它必须处于 open 状态才能读取块,否则会出现段错误.

import dask.array as da
import numpy as np

class FileReader:
    _open = True

    def open(self):
        self._open = True

    def close(self):
        self._open = False

    def to_dask(self) -> da.Array:
        return da.map_blocks(
            self._dask_block,
            chunks=((1,) * 4, 4, 4),
            dtype=float,
        )

    def _dask_block(self):
        if not self._open:
            raise RuntimeError("Segfault!")
        return np.random.rand(1, 4, 4)

如果文件保持打开状态,一切都很好,但如果关闭文件,从 to_dask 返回的 dask 数组将失败:

>>> t = FileReader()
>>> darr = t.to_dask()
>>> t.close()
>>> darr.compute()  # RuntimeError: Segfault!

当前任务图如下:

>>> list(darr.dask)
[
    ('_dask_block-aa4daac0835bafe001693f9ac085683a', 0, 0, 0),
    ('_dask_block-aa4daac0835bafe001693f9ac085683a', 1, 0, 0),
    ('_dask_block-aa4daac0835bafe001693f9ac085683a', 2, 0, 0),
    ('_dask_block-aa4daac0835bafe001693f9ac085683a', 3, 0, 0)
]

本质上,我想在开头添加新任务,_dask_block 层依赖,并在结尾添加一个任务,这取决于 _dask_block

我尝试直接操作 da.map_blocksHighLevelGraph 输出来手动添加这些任务,但发现它们在计算优化期间被修剪了,因为 darr.__dask_keys__() 不包含我的键(同样,我想避免 subclassing 或要求最终用户使用特殊优化标志调用 compute)。

一个解决方案是确保传递给 map_blocks 的 _dask_block 函数始终打开和关闭底层资源...但我们假设 open/close 进程相对缓慢,有效地破坏了单机并行性的性能。所以我们只想要一个在开始时打开,在结束时关闭。

我可以通过在调用 map_blocks 时包含一个新密钥来稍微“作弊”以包含打开我的文件的任务,如下所示:

    ...
    
    # new method that will run at beginning of compute()
    def _pre_compute(self):
        was_open = self._open
        if not was_open:
            self.open()
        return was_open

    def to_dask(self) -> da.Array:
        # new task key
        pre_task = 'pre_compute-' + tokenize(self._pre_compute)
        arr = da.map_blocks(
            self._dask_block,
            pre_task,  # add key here so all chunks depend on it
            chunks=((1,) * 4, 4, 4),
            dtype=float,
        )
        # add key to HighLevelGraph
        arr.dask.layers[pre_task] = {pre_task: (self._pre_compute,)}
        return da.Array(arr.dask, arr.name, arr.chunks, arr.dtype)

    # add "mock" argument to map_blocks function
    def _dask_block(self, _):
        if not self._open:
            raise RuntimeError("Segfault!")
        return np.random.rand(1, 4, 4)

到目前为止一切顺利,没有更多 RuntimeError...但是现在我已经泄漏了文件句柄,因为最后没有关闭它。

然后我想要的是图表末尾的任务,它取决于 pre_task 的输出(即是否必须为此计算打开文件),并关闭如果必须打开文件。

这就是我卡住的地方,因为 post-compute 键将被优化器修剪...

有没有办法在不创建覆盖 __dask_postcompute____dask_keys__ 等方法的自定义数组子 class 的情况下执行此操作,或者要求最终用户在不进行优化的情况下调用计算?

这是一个非常有趣的问题。我认为您在编辑任务图以包括用于打开和关闭共享资源的任务方面走在正确的轨道上。但是手动图形操作非常繁琐且难以正确处理。

我认为完成你想要的最简单的方法是使用一些相对最近添加的实用程序来处理 dask.graph_manipulation 中的图形操作。特别是,我想我们想要 bind,它可以用来向 Dask 集合添加隐式依赖,以及 wait_for,它可以用来确保 dependents 的集合等待另一个不相关的集合。

我通过修改您的示例以使用这些实用程序来创建各种 to_dask() 是自动打开和关闭的:

import dask
import dask.array as da
import numpy as np
from dask.graph_manipulation import bind, checkpoint, wait_on


class FileReader:
    _open = False

    def open(self):
        self._open = True

    def close(self):
        self._open = False

    def to_dask(self) -> da.Array:
        # Delayed version of self.open
        @dask.delayed
        def open_resource():
            self.open()
        # Delayed version of self.close
        @dask.delayed
        def close_resource():
            self.close()
            
        opener = open_resource()
        arr = da.map_blocks(
            self._dask_block,
            chunks=((1,) * 4, 4, 4),
            dtype=float,
        )
        # Make sure the array is dependent on `opener`
        arr = bind(arr, opener)

        closer = close_resource()
        # Make sure the closer is dependent on the array being done
        closer = bind(closer, arr)
        # Make sure dependents of arr happen after `closer` is done
        arr, closer = wait_on(arr, closer)
        return arr

    def _dask_block(self):
        if not self._open:
            raise RuntimeError("Segfault!")
        return np.random.rand(1, 4, 4)

我发现查看操作前后的任务图很有趣。

之前是比较直接的分块数组:

但是在操作之后,你可以看到数组块依赖于 open_resource,然后这些块流入 close_resource,然后流入让数组块进入更广阔的世界: