对于大小为 (M, N) 的大于内存的 dask 数组:如何从 chunks=(1, N) 重新分块到 chunks=(M, 1)?

For a bigger-than-memory dask array of size=(M, N): How to re-chunk from chunks=(1, N) to chunks=(M, 1)?

例如,为了沿整个轴应用编码为 Numpy/Numba 的 IIR 滤波器,我需要将 size=(M, N) dask 数组从 chunks=(m0, n0) 重新分块到chunks=(m1, N)m1 < m0.

由于 Dask 避免重复任务,在 rechunk-split/rechunk-merge 期间,它将在内存中存储价值 (m0, N) (x 2?) 的数据。有没有办法优化图形来避免这种行为?

我知道在哪里可以找到手动优化 Dask 图的信息。但是有没有一种方法可以调整调度策略以允许重复任务或(自动)重新排列图形以最大限度地减少此重新分块期间的内存使用?

这是一个最小的例子(针对 chunks=(M, 1)chunks=(1, N) 的极端情况):

from dask import array as da
from dask.distributed import Client

# limit memory to 4 GB
client = Client(memory_limit=4e9)

# Create 80 GB random array with chunks=(M, 1)
arr = da.random.uniform(-1, 1, size=(1e5, 1e5), chunks=(1e5, 1))

# Compute mean (This works!)
arr.mean().compute()

# Rechunk to chunks=(1, N)
arr = arr.rechunk((1, 1e5))

# Compute mean (This hits memory limit!)
arr.mean().compute()

不幸的是,您处于最坏的情况下,您需要先计算每个输入块,然后才能获得单个输出块。

Dask 的重新分块操作很不错,他们会在此期间将东西重新分块成中间大小的块,所以这可能会在内存不足的情况下工作,但你肯定会写东西到磁盘。

总之,原则上没有什么是你应该额外做的。理论上,Dask 的重分块算法应该可以解决这个问题。如果你愿意,你可以使用 threshold=block_size_limit= 关键字来重新分块。

block_size_limit= 关键字导致某种解决方案。

(下面,我使用了一个较小的阵列,因为我没有剩余 80GB 的磁盘可以溢出。)

from dask import array as da
from dask.distributed import Client

# limit memory to 1 GB
client = Client(n_workers=1, threads_per_worker=1, memory_limit=1e9)

# Create 3.2 GB array
arr = da.random.uniform(-1, 1, size=(2e4, 2e4), chunks=(2e4, 1e1))

# Check graph size
print(len(arr.__dask_graph__()), "nodes in graph")  # 2000 nodes

# Compute
print(arr.mean().compute())  # Takes 11.9 seconds. Doesn't spill.

# re-create array and rechunk with block_size_limit=1e3
arr = da.random.uniform(-1, 1, size=(2e4, 2e4), chunks=(2e4, 1e1))
arr = arr.rechunk((2e1, 2e4), block_size_limit=1e3)

# Check graph size
print(len(arr.__dask_graph__()), "nodes in graph")  # 32539 nodes

# Compute
print(arr.mean().compute())  # Takes 140 seconds, spills ~5GB to disk.

# re-create array and rechunk with default kwargs
arr = da.random.uniform(-1, 1, size=(2e4, 2e4), chunks=(2e4, 1e1))
arr = arr.rechunk((2e1, 2e4))

# Check graph size
print(len(arr.__dask_graph__()), "nodes in graph")  # 9206 nodes

# Compute
print(arr.mean().compute())  # Worker dies at 95% memory use