Dask looping over library function call

Goal

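I'd like to use dask to parallelize a loop that calls a library function. That function, mhw.detect(), calculates some statistics on a slice of a numpy array. None of the slices depend on any other slice, so I was hoping dask could compute them in parallel and store them all in the same output array.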

Code

The code flow I'm working with is:

import numpy as np
import marineHeatWaves as mhw
from dask import delayed

# Create fake input data
lat_size, long_size = 100, 100
# np.random.random_integers is deprecated; randint excludes the upper bound, so 31 keeps values in 0..30
data = np.random.randint(0, 31, size=(10_000, long_size, lat_size))  # size = (time, longitude, latitude)
time = np.arange(730_000, 740_000)  # time in ordinal days

# Initialize an empty array to hold the output
output_array = np.empty(data.shape)

# loop through each pixel in the data array
for idx_lat in range(lat_size):
    for idx_long in range(long_size):
        # Extract a slice of data
        data_slice = data[:, idx_lat, idx_long]
        # Use the library function to calculate the stats for the pixel
        # `library_output` is a dictionary that has a numpy array inside it
        _, library_output = delayed(mhw.detect)(time, data_slice)
        # Update the output array with the calculated values from the library
        output_array[:, idx_lat, idx_long] = library_output['seas']

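Previous attempts

When I run this code I get the error TypeError: Delayed objects of unspecified length are not iterable. Another Stack Overflow post discusses this issue and resolves it by converting the output of the delayed function into a delayed object. However, because I am not creating the output object myself, I'm not sure whether I can convert it into a delayed object.

I have also tried wrapping the last line in da.from_delayed(), as in output_array[:, idx_lat, idx_long] = da.from_delayed(library_output['seas']), and initializing output_array with da.empty(data.shape). I get the same error, though, because I think the code never gets past the library function call delayed(mhw.detect)(time, data_slice).

Is it possible to parallelize this? Is this approach of asking dask to compute all the slices in parallel and put them in an output array reasonable?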

Full traceback

TypeError                                 Traceback (most recent call last)
/home/rwegener/mhw-ocetrac-census/notebooks/ejoliver_subset_MUR.ipynb Cell 44' in <cell line: 10>()
     13 data_slice = data[:, idx_lat, idx_long]
     14 # Use the library function to calculate the stats for the pixel
---> 15 _, point_clim = delayed(mhw.detect)(time_ordinal, data_slice)
     16 # Update the output array with the calculated values from the library
     17 output_array[:, idx_lat, idx_long] = point_clim['seas']

File ~/.conda/envs/dask/lib/python3.10/site-packages/dask/delayed.py:581, in Delayed.__iter__(self)
    579 def __iter__(self):
    580     if self._length is None:
--> 581         raise TypeError("Delayed objects of unspecified length are not iterable")
    582     for i in range(self._length):
    583         yield self[i]

TypeError: Delayed objects of unspecified length are not iterable
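
The traceback shows the error comes from Delayed.__iter__: the tuple unpacking `_, library_output = ...` tries to iterate over a Delayed object whose length dask does not know. One way around it, sketched below, is the `nout` argument of `dask.delayed`, which declares how many items the call returns. The function `fake_detect` is a hypothetical stand-in for mhw.detect (assumed here to return a pair of dicts, as the question's comments describe), and the array sizes are shrunk so the sketch runs quickly.

```python
import numpy as np
import dask
from dask import delayed


def fake_detect(t, temp):
    """Hypothetical stand-in for mhw.detect: returns an (events, climatology) pair."""
    events = {"n_events": 0}
    clim = {"seas": np.full(temp.shape, temp.mean())}
    return events, clim


n_time, n_lat, n_lon = 50, 3, 4
rng = np.random.default_rng(0)
data = rng.integers(0, 31, size=(n_time, n_lat, n_lon))
time = np.arange(730_000, 730_000 + n_time)

# nout=2 tells dask the call returns two items, so unpacking works lazily
lazy_detect = delayed(fake_detect, nout=2)

pixels, tasks = [], []
for i in range(n_lat):
    for j in range(n_lon):
        _, clim = lazy_detect(time, data[:, i, j])
        pixels.append((i, j))
        tasks.append(clim["seas"])  # still lazy: indexing a Delayed gives a Delayed

# Run all tasks in one call, then copy the concrete arrays into the output
results = dask.compute(*tasks)
output_array = np.empty(data.shape, dtype=float)
for (i, j), seas in zip(pixels, results):
    output_array[:, i, j] = seas
```

Note that a Delayed value cannot be assigned straight into a numpy array, which is why the results are computed first and copied afterwards. With one task per pixel the scheduling overhead can become noticeable on a 100 x 100 grid, so this is only a sketch of the delayed route, not a recommendation over a chunked dask.array approach.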

Update

Using .apply_along_axis() as suggested:

import dask.array

# Create fake input data
lat_size, long_size = 100, 100
data = np.random.randint(0, 30, size=(10_000, long_size, lat_size))  # size = (time, longitude, latitude)
data = dask.array.from_array(data, chunks=(-1, 100, 100))
time = np.arange(730_000, 740_000)  # time in ordinal days

# Initialize an empty array to hold the output
output_array = np.empty(data.shape)

# define a wrapper to rearrange arguments
def func1d(arr, time, shape=(10000,)):
   print(arr.shape)
   return mhw.detect(time, arr)

res = dask.array.apply_along_axis(func1d, 0, data, time=time)

Output:

(1,)
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
/homes/metogra/rwegener/mhw-ocetrac-census/notebooks/ejoliver_subset_MUR.ipynb Cell 48' in <cell line: 15>()
     12    print(arr.shape)
     13    return mhw.detect(time, arr)
---> 15 res = dask.array.apply_along_axis(func1d, 0, data, time=time)

File ~/.conda/envs/dask/lib/python3.10/site-packages/dask/array/routines.py:508, in apply_along_axis(func1d, axis, arr, dtype, shape, *args, **kwargs)
    506 if shape is None or dtype is None:
    507     test_data = np.ones((1,), dtype=arr.dtype)
--> 508     test_result = np.array(func1d(test_data, *args, **kwargs))
    509     if shape is None:
    510         shape = test_result.shape

/homes/metogra/rwegener/mhw-ocetrac-census/notebooks/ejoliver_subset_MUR.ipynb Cell 48' in func1d(arr, time, shape)
     11 def func1d(arr, time, shape=(10000,)):
     12    print(arr.shape)
---> 13    return mhw.detect(time, arr)

File ~/.conda/envs/dask/lib/python3.10/site-packages/marineHeatWaves-0.28-py3.10.egg/marineHeatWaves.py:280, in detect(t, temp, climatologyPeriod, pctile, windowHalfWidth, smoothPercentile, smoothPercentileWidth, minDuration, joinAcrossGaps, maxGap, maxPadLength, coldSpells, alternateClimatology, Ly)
    278     tt = tt[tt>=0] # Reject indices "before" the first element
    279     tt = tt[tt<TClim] # Reject indices "after" the last element
--> 280     thresh_climYear[d-1] = np.nanpercentile(tempClim[tt.astype(int)], pctile)
    281     seas_climYear[d-1] = np.nanmean(tempClim[tt.astype(int)])
    282 # Special case for Feb 29

IndexError: index 115 is out of bounds for axis 0 with size 1

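Rather than using delayed, this seems like a good use case for dask.array.

You can create a dask array by partitioning the numpy input array into chunks, keeping the whole time axis in each chunk:

data = dask.array.from_array(data, chunks=(-1, 10, 10))

Now you can call mhw.detect within each block by using dask.array.map_blocks alongside np.apply_along_axis: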

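# define a wrapper to rearrange arguments and keep only the array output
def func1d(arr, time):
    # mhw.detect returns a pair; as in the question, the second item holds 'seas'
    _, clim = mhw.detect(time, arr)
    return clim['seas']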

def block_func(block, **kwargs):
    return np.apply_along_axis(func1d, 0, block, **kwargs)

res = data.map_blocks(block_func, meta=data, time=time)
res = res.compute()
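
For anyone who wants to try the pattern without marineHeatWaves installed, here is a minimal self-contained sketch of the same map_blocks idea. `fake_seas` is a hypothetical stand-in for the 'seas' output of mhw.detect, the sizes are arbitrary, and passing `dtype` together with a small numpy `meta` is my own choice for declaring the output type, not something taken from the answer above.

```python
import numpy as np
import dask.array as da


def fake_seas(t, temp):
    """Hypothetical stand-in: one value per time step, like the 'seas' array."""
    return np.full(t.shape, temp.mean(), dtype=float)


def func1d(arr, time):
    # rearrange arguments to match the (time, series) order of the library call
    return fake_seas(time, arr)


def block_func(block, time):
    # each block holds the full time axis for a 10 x 10 patch of pixels
    return np.apply_along_axis(func1d, 0, block, time=time)


n_time = 60
rng = np.random.default_rng(0)
raw = rng.integers(0, 31, size=(n_time, 20, 20))
time = np.arange(730_000, 730_000 + n_time)

# -1 keeps the whole time axis in a single chunk, so no series is split
data = da.from_array(raw, chunks=(-1, 10, 10))

res = data.map_blocks(
    block_func,
    time=time,
    dtype=float,
    meta=np.array((), dtype=float),
)
out = res.compute()
```

The chunking matters here: because the per-pixel function needs the complete time series, only the two spatial axes are split, and each block can be processed independently.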

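The map_blocks answer above works great! In addition, apply_along_axis() was suggested and discussed in the comments. I was able to get that approach working, but for it to run correctly you need to pass both the dtype and shape arguments to da.apply_along_axis(). If they are not provided, dask calls func1d on a length-1 test array to infer the output shape and dtype (visible in the traceback above), and mhw.detect cannot handle input that short. So, another solution: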

import dask.array as da

# Create fake input data
lat_size, long_size = 100, 100
# randint excludes the upper bound, so 31 keeps values in 0..30
data = da.random.randint(0, 31, size=(1_000, long_size, lat_size), chunks=(-1, 10, 10))  # size = (time, longitude, latitude)
time = np.arange(730_000, 731_000)  # time in ordinal days

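# define a wrapper to rearrange arguments and keep only the array output
def func1d(arr, time):
    # mhw.detect returns a pair; the 'seas' array has one value per time step
    _, clim = mhw.detect(time, arr)
    return clim['seas']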

result = da.apply_along_axis(func1d, 0, data, time=time, dtype=data.dtype, shape=(1000,))
result.compute()
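
The effect of dtype and shape can be reproduced without the library. In the sketch below, `picky_stat` is a hypothetical stand-in that, like mhw.detect, fails on very short input; `MIN_LEN` and the array sizes are arbitrary assumptions. The first call omits dtype and shape and is expected to fail while dask probes the function with a tiny test array; the second supplies both, so no probe is needed.

```python
import numpy as np
import dask.array as da

MIN_LEN = 30  # hypothetical minimum series length the stand-in accepts


def picky_stat(t, temp):
    """Hypothetical stand-in that rejects very short series."""
    if temp.shape[0] < MIN_LEN:
        raise IndexError(f"series of length {temp.shape[0]} is too short")
    return temp - temp.mean()


def func1d(arr, time):
    # rearrange arguments to match the (time, series) order of the library call
    return picky_stat(time, arr)


n_time = 60
time = np.arange(730_000, 730_000 + n_time)
raw = np.random.default_rng(0).integers(0, 31, size=(n_time, 20, 20))
data = da.from_array(raw, chunks=(-1, 10, 10))

# Without dtype and shape, dask has to call func1d on test data to infer them
try:
    da.apply_along_axis(func1d, 0, data, time=time)
    probe_failed = False
except Exception as err:
    probe_failed = True
    print("inference failed:", err)

# With both supplied, the output shape and dtype are taken as given
result = da.apply_along_axis(
    func1d, 0, data, time=time, dtype=float, shape=(n_time,)
)
out = result.compute()
```

Here `shape` describes what func1d returns for a single series, not the shape of the whole output, which is why it is `(n_time,)`.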