对于大小为 (M, N) 的大于内存的 dask 数组:如何从 chunks=(1, N) 重新分块到 chunks=(M, 1)?
For a bigger-than-memory dask array of size=(M, N): How to re-chunk from chunks=(1, N) to chunks=(M, 1)?
例如,为了沿整个轴应用编码为 Numpy/Numba 的 IIR 滤波器,我需要将 size=(M, N)
dask 数组从 chunks=(m0, n0)
重新分块到chunks=(m1, N)
与 m1 < m0
.
由于 Dask 避免重复任务,在 rechunk-split/rechunk-merge 期间,它将在内存中存储价值 (m0, N)
(x 2?) 的数据。有没有办法优化图形来避免这种行为?
我知道在哪里可以找到手动优化 Dask 图的信息。但是有没有一种方法可以调整调度策略以允许重复任务或(自动)重新排列图形以最大限度地减少此重新分块期间的内存使用?
这是一个最小的例子(针对 chunks=(M, 1)
→ chunks=(1, N)
的极端情况):
from dask import array as da
from dask.distributed import Client
# limit memory to 4 GB
client = Client(memory_limit=4e9)
# Create 80 GB random array with chunks=(M, 1)
arr = da.random.uniform(-1, 1, size=(1e5, 1e5), chunks=(1e5, 1))
# Compute mean (This works!)
arr.mean().compute()
# Rechunk to chunks=(1, N)
arr = arr.rechunk((1, 1e5))
# Compute mean (This hits memory limit!)
arr.mean().compute()
不幸的是,您处于最坏的情况下,您需要先计算每个输入块,然后才能获得单个输出块。
Dask 的重新分块操作很不错,他们会在此期间将东西重新分块成中间大小的块,所以这可能会在内存不足的情况下工作,但你肯定会写东西到磁盘。
总之,原则上没有什么是你应该额外做的。理论上,Dask 的重分块算法应该可以解决这个问题。如果你愿意,你可以使用 threshold=
和 block_size_limit=
关键字来重新分块。
block_size_limit=
关键字导致某种解决方案。
(下面,我使用了一个较小的阵列,因为我没有剩余 80GB 的磁盘可以溢出。)
from dask import array as da
from dask.distributed import Client
# limit memory to 1 GB
client = Client(n_workers=1, threads_per_worker=1, memory_limit=1e9)
# Create 3.2 GB array
arr = da.random.uniform(-1, 1, size=(2e4, 2e4), chunks=(2e4, 1e1))
# Check graph size
print(len(arr.__dask_graph__()), "nodes in graph") # 2000 nodes
# Compute
print(arr.mean().compute()) # Takes 11.9 seconds. Doesn't spill.
# re-create array and rechunk with block_size_limit=1e3
arr = da.random.uniform(-1, 1, size=(2e4, 2e4), chunks=(2e4, 1e1))
arr = arr.rechunk((2e1, 2e4), block_size_limit=1e3)
# Check graph size
print(len(arr.__dask_graph__()), "nodes in graph") # 32539 nodes
# Compute
print(arr.mean().compute()) # Takes 140 seconds, spills ~5GB to disk.
# re-create array and rechunk with default kwargs
arr = da.random.uniform(-1, 1, size=(2e4, 2e4), chunks=(2e4, 1e1))
arr = arr.rechunk((2e1, 2e4))
# Check graph size
print(len(arr.__dask_graph__()), "nodes in graph") # 9206 nodes
# Compute
print(arr.mean().compute()) # Worker dies at 95% memory use
例如,为了沿整个轴应用编码为 Numpy/Numba 的 IIR 滤波器,我需要将 size=(M, N)
dask 数组从 chunks=(m0, n0)
重新分块到chunks=(m1, N)
与 m1 < m0
.
由于 Dask 避免重复任务,在 rechunk-split/rechunk-merge 期间,它将在内存中存储价值 (m0, N)
(x 2?) 的数据。有没有办法优化图形来避免这种行为?
我知道在哪里可以找到手动优化 Dask 图的信息。但是有没有一种方法可以调整调度策略以允许重复任务或(自动)重新排列图形以最大限度地减少此重新分块期间的内存使用?
这是一个最小的例子(针对 chunks=(M, 1)
→ chunks=(1, N)
的极端情况):
from dask import array as da
from dask.distributed import Client
# limit memory to 4 GB
client = Client(memory_limit=4e9)
# Create 80 GB random array with chunks=(M, 1)
arr = da.random.uniform(-1, 1, size=(1e5, 1e5), chunks=(1e5, 1))
# Compute mean (This works!)
arr.mean().compute()
# Rechunk to chunks=(1, N)
arr = arr.rechunk((1, 1e5))
# Compute mean (This hits memory limit!)
arr.mean().compute()
不幸的是,您处于最坏的情况下,您需要先计算每个输入块,然后才能获得单个输出块。
Dask 的重新分块操作很不错,他们会在此期间将东西重新分块成中间大小的块,所以这可能会在内存不足的情况下工作,但你肯定会写东西到磁盘。
总之,原则上没有什么是你应该额外做的。理论上,Dask 的重分块算法应该可以解决这个问题。如果你愿意,你可以使用 threshold=
和 block_size_limit=
关键字来重新分块。
block_size_limit=
关键字导致某种解决方案。
(下面,我使用了一个较小的阵列,因为我没有剩余 80GB 的磁盘可以溢出。)
from dask import array as da
from dask.distributed import Client
# limit memory to 1 GB
client = Client(n_workers=1, threads_per_worker=1, memory_limit=1e9)
# Create 3.2 GB array
arr = da.random.uniform(-1, 1, size=(2e4, 2e4), chunks=(2e4, 1e1))
# Check graph size
print(len(arr.__dask_graph__()), "nodes in graph") # 2000 nodes
# Compute
print(arr.mean().compute()) # Takes 11.9 seconds. Doesn't spill.
# re-create array and rechunk with block_size_limit=1e3
arr = da.random.uniform(-1, 1, size=(2e4, 2e4), chunks=(2e4, 1e1))
arr = arr.rechunk((2e1, 2e4), block_size_limit=1e3)
# Check graph size
print(len(arr.__dask_graph__()), "nodes in graph") # 32539 nodes
# Compute
print(arr.mean().compute()) # Takes 140 seconds, spills ~5GB to disk.
# re-create array and rechunk with default kwargs
arr = da.random.uniform(-1, 1, size=(2e4, 2e4), chunks=(2e4, 1e1))
arr = arr.rechunk((2e1, 2e4))
# Check graph size
print(len(arr.__dask_graph__()), "nodes in graph") # 9206 nodes
# Compute
print(arr.mean().compute()) # Worker dies at 95% memory use