沿 dask 数组的轴应用函数

Applying a function along an axis of a dask array

我正在分析来自气候模型模拟的海洋温度数据,其中 4D 数据阵列(时间、深度、纬度、经度;在下面表示为 dask_array)通常具有 (6000, 31, 189) 的形状, 192) 和 ~25GB 的大小(因此我希望使用 dask;我一直在尝试使用 numpy 处理这些数组时遇到内存错误)。

我需要在每个级别/纬度/经度点沿时间轴拟合三次多项式并存储生成的 4 个系数。因此,我设置了 chunksize=(6000, 1, 1, 1),因此每个网格点都有一个单独的块。

这是我获取三次多项式系数的函数(time_axis 轴值是在别处定义的全局一维 numpy 数组):

def my_polyfit(data):    
    return numpy.polyfit(data.squeeze(), time_axis, 3)

(所以在这种情况下,numpy.polyfit returns 一个长度为 4 的列表)

这是我认为我需要将其应用于每个块的命令:

dask_array.map_blocks(my_polyfit, chunks=(4, 1, 1, 1), drop_axis=0, new_axis=0).compute()

因此时间轴现在消失了(因此 drop_axis=0)并且在它的位置有一个新的系数轴(长度为 4)。

当我 运行 这个命令时,我得到 IndexError: tuple index out of range,所以我想 where/how 我误解了 map_blocks 的用法?

我怀疑如果您的函数 returns 一个与它消耗的维度相同的数组,您的体验会更流畅。例如。您可以考虑按如下方式定义函数:

def my_polyfit(data):
    return np.polyfit(data.squeeze(), ...)[:, None, None, None]

那么你可以忽略new_axisdrop_axis位。

在性能方面,您可能还想考虑使用更大的块大小。如果每个块有 6000 个数字,您就有超过一百万个块,这意味着您花在调度上的时间可能比花在实际计算上的时间更多。通常,我拍摄的是几兆字节大小的块。当然,增加 chunksize 会导致映射函数变得更加复杂。

例子

In [1]: import dask.array as da

In [2]: import numpy as np

In [3]: def f(b):
    return np.polyfit(b.squeeze(), np.arange(5), 3)[:, None, None, None]
   ...: 

In [4]: x = da.random.random((5, 3, 3, 3), chunks=(5, 1, 1, 1))

In [5]: x.map_blocks(f, chunks=(4, 1, 1, 1)).compute()
Out[5]: 
array([[[[ -1.29058580e+02,   2.21410738e+02,   1.00721521e+01],
         [ -2.22469851e+02,  -9.14889627e+01,  -2.86405832e+02],
         [  1.40415805e+02,   3.58726232e+02,   6.47166710e+02]],
         ...

聚会有点晚了,但我认为这可以使用基于 Dask 新功能的替代答案。特别是,我们添加了 apply_along_axis, which behaves basically like NumPy's apply_along_axis 除了 Dask Arrays。这导致了更简单的语法。此外,它还避免了在将自定义函数应用于每个一维块之前重新分块数据的需要,并且对初始分块没有真正的要求,它试图在最终结果中保留(除了减少或替换的轴) .

In [1]: import dask.array as da

In [2]: import numpy as np

In [3]: def f(b):
   ...:     return np.polyfit(b, np.arange(len(b)), 3)
   ...: 

In [4]: x = da.random.random((5, 3, 3, 3), chunks=(5, 1, 1, 1))

In [5]: da.apply_along_axis(f, 0, x).compute()
Out[5]: 
array([[[[  2.13570599e+02,   2.28924503e+00,   6.16369231e+01],
         [  4.32000311e+00,   7.01462518e+01,  -1.62215514e+02],
         [  2.89466687e+02,  -1.35522215e+02,   2.86643721e+02]],
         ...