Efficiently mask and calculate means for multiple groups in `xr.Dataset` xarray

I have two `xr.Dataset` objects. One is a continuous map of some variable (here, precipitation). The other is a categorical map of a set of regions `['region_1', 'region_2', 'region_3', 'region_4']`.

I want to use the region mask to calculate the mean precip value for each region at each timestep, and then output a dataframe that looks like the one below.

In [6]: df.head()
Out[6]:
    datetime region_name          mean_value
0 2008-01-31    region_1   51.77333333333333
1 2008-02-29    region_1   44.87555555555556
2 2008-03-31    region_1   50.88444444444445
3 2008-04-30    region_1   48.50666666666667
4 2008-05-31    region_1  47.653333333333336

I have some code that works, but it runs far too slowly on my real datasets. Can anyone help me optimise it?

A minimal, reproducible example

Initialise our objects: two variables of the same shape. In practice the region object will be read from a shapefile and will have more than two regions.

import xarray as xr
import pandas as pd
import numpy as np

def make_dataset(
    variable_name='precip',
    size=(30, 30),
    start_date='2008-01-01',
    end_date='2010-01-01',
    lonmin=-180.0,
    lonmax=180.0,
    latmin=-55.152,
    latmax=75.024,
):
    # create 2D lat/lon dimension
    lat_len, lon_len = size
    longitudes = np.linspace(lonmin, lonmax, lon_len)
    latitudes = np.linspace(latmin, latmax, lat_len)
    dims = ["lat", "lon"]
    coords = {"lat": latitudes, "lon": longitudes}

    # add time dimension
    times = pd.date_range(start_date, end_date, name="time", freq="M")
    size = (len(times), size[0], size[1])
    dims.insert(0, "time")
    coords["time"] = times

    # create values
    var = np.random.randint(100, size=size)

    return xr.Dataset({variable_name: (dims, var)}, coords=coords), size

ds, size = make_dataset()

# create dummy regions (not contiguous but doesn't matter for this example)
region_ds = xr.ones_like(ds).rename({'precip': 'region'})
array = np.random.choice([0, 1, 2, 3], size=size)
region_ds = region_ds * array

# create a dictionary explaining what the regions are
region_lookup = {
    0: 'region_1',
    1: 'region_2',
    2: 'region_3',
    3: 'region_4',
}

What do these objects look like?

In[]: ds

Out[]:
<xarray.Dataset>
Dimensions:  (lat: 30, lon: 30, time: 24)
Coordinates:
  * lat      (lat) float64 -55.15 -50.66 -46.17 -41.69 ... 66.05 70.54 75.02
  * lon      (lon) float64 -180.0 -167.6 -155.2 -142.8 ... 155.2 167.6 180.0
  * time     (time) datetime64[ns] 2008-01-31 2008-02-29 ... 2009-12-31
Data variables:
    precip   (time, lat, lon) int64 51 92 14 71 60 20 82 ... 16 33 34 98 23 53

In[]: region_ds

Out[]:
<xarray.Dataset>
Dimensions:  (lat: 30, lon: 30, time: 24)
Coordinates:
  * lat      (lat) float64 -55.15 -50.66 -46.17 -41.69 ... 66.05 70.54 75.02
  * time     (time) datetime64[ns] 2008-01-31 2008-02-29 ... 2009-12-31
  * lon      (lon) float64 -180.0 -167.6 -155.2 -142.8 ... 155.2 167.6 180.0
Data variables:
    region   (time, lat, lon) float64 0.0 0.0 0.0 0.0 0.0 ... 1.0 1.0 1.0 1.0

Current implementation

To calculate the mean of the variable in `ds` for each region `['region_1', 'region_2', ...]` in `region_ds` at each timestep, I loop over each REGION and then over each TIMESTEP in the `da` object.

This gets very slow as the dataset grows (more pixels and more timesteps). Is there a more efficient, vectorised use of numpy / xarray that gets me the result I want faster?

def drop_nans_and_flatten(dataArray: xr.DataArray) -> np.ndarray:
    """flatten the array and drop nans from that array. Useful for plotting histograms.

    Arguments:
    ---------
    : dataArray (xr.DataArray)
        the DataArray of your value you want to flatten
    """
    # drop NaNs and flatten
    return dataArray.values[~np.isnan(dataArray.values)]
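As a quick sanity check (my own sketch, not part of the original question), the helper above can be exercised on a tiny `DataArray` containing a NaN:

```python
import numpy as np
import xarray as xr

def drop_nans_and_flatten(dataArray: xr.DataArray) -> np.ndarray:
    """Flatten the array and drop NaNs from it."""
    return dataArray.values[~np.isnan(dataArray.values)]

# a 2x2 DataArray with one NaN
da = xr.DataArray([[1.0, np.nan], [2.0, 3.0]], dims=["lat", "lon"])
flat = drop_nans_and_flatten(da)
print(flat)  # the three non-NaN values, flattened row-major
```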

#
da = ds.precip
region_da = region_ds.region
valid_region_ids = [k for k in region_lookup.keys()]

# initialise empty lists
region_names = []
datetimes = []
mean_values = []

for valid_region_id in valid_region_ids:
    for time in da.time.values:
        region_names.append(region_lookup[valid_region_id])
        datetimes.append(time)
        # extract all non-nan values for that time-region
        mean_values.append(
            da.sel(time=time).where(region_da == valid_region_id).mean().values
        )

df = pd.DataFrame(
    {
        "datetime": datetimes,
        "region_name": region_names,
        "mean_value": mean_values,
    }
)

Output:

In [6]: df.head()
Out[6]:
    datetime region_name          mean_value
0 2008-01-31    region_1   51.77333333333333
1 2008-02-29    region_1   44.87555555555556
2 2008-03-31    region_1   50.88444444444445
3 2008-04-30    region_1   48.50666666666667
4 2008-05-31    region_1  47.653333333333336

In [7]: df.tail()
Out[7]:
     datetime region_name          mean_value
43 2009-08-31    region_4   50.83111111111111
44 2009-09-30    region_4   48.40888888888889
45 2009-10-31    region_4   51.56148148148148
46 2009-11-30    region_4  48.961481481481485
47 2009-12-31    region_4   48.36296296296296

In [20]: df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 96 entries, 0 to 95
Data columns (total 3 columns):
datetime       96 non-null datetime64[ns]
region_name    96 non-null object
mean_value     96 non-null object
dtypes: datetime64[ns](1), object(2)
memory usage: 2.4+ KB

In [21]: df.describe()
Out[21]:
                   datetime region_name         mean_value
count                    96          96                 96
unique                   24           4                 96
top     2008-10-31 00:00:00    region_1  48.88984800150122
freq                      4          24                  1
first   2008-01-31 00:00:00         NaN                NaN
last    2009-12-31 00:00:00         NaN                NaN

Any help is much appreciated, thanks!

Given how the regions are defined, it's hard to avoid iterating to generate the masks for each region, but once those are constructed (e.g. with the code below), I think the following will be quite efficient:

regions = xr.concat(
    [(region_ds.region == region_id).expand_dims(region=[region])
     for region_id, region in region_lookup.items()], 
    dim='region'
)
result = ds.precip.where(regions).mean(['lat', 'lon'])

This produces a DataArray with `'time'` and `'region'` dimensions, where the value at each point is the mean at a given time within a given region. It would be straightforward to extend this to an area-weighted mean if that were also needed.
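If the long-format DataFrame from the question is still wanted, the per-region means can be reshaped with `to_dataframe` and `reset_index`. A minimal self-contained sketch (the small `precip` / `region` arrays and the column names are my own stand-ins mirroring the question):

```python
import numpy as np
import pandas as pd
import xarray as xr

# minimal stand-ins for ds / region_ds from the question
times = pd.to_datetime(["2008-01-31", "2008-02-29", "2008-03-31"])
precip = xr.DataArray(
    np.arange(3 * 2 * 2, dtype=float).reshape(3, 2, 2),
    coords={"time": times, "lat": [0.0, 1.0], "lon": [0.0, 1.0]},
    dims=["time", "lat", "lon"],
    name="precip",
)
region = xr.DataArray(
    [[0, 0], [1, 1]],
    coords={"lat": [0.0, 1.0], "lon": [0.0, 1.0]},
    dims=["lat", "lon"],
)
region_lookup = {0: "region_1", 1: "region_2"}

# build the stack of boolean masks and take the per-region spatial mean
regions = xr.concat(
    [(region == rid).expand_dims(region=[name]) for rid, name in region_lookup.items()],
    dim="region",
)
result = precip.where(regions).mean(["lat", "lon"])

# reshape the (time, region) DataArray into the long-format DataFrame
df = (
    result.to_dataframe("mean_value")
    .reset_index()
    .rename(columns={"time": "datetime", "region": "region_name"})
    [["datetime", "region_name", "mean_value"]]
)
print(df)
```

Each row holds the spatial mean of one region at one timestep, matching the shape of the `df` shown in the question.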


An alternative option that produces the same result is:

regions = xr.DataArray(
    list(region_lookup.keys()),
    coords=[list(region_lookup.values())],
    dims=['region']
)
result = ds.precip.where(regions == region_ds.region).mean(['lat', 'lon'])

Here `regions` is basically just a DataArray representation of the `region_lookup` dictionary.
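The broadcasting behind that comparison can be illustrated on a tiny array (my own sketch, not part of the original answer): comparing a 1-D `regions` array against a 2-D map of region ids broadcasts to a 3-D stack of boolean masks, one per region.

```python
import xarray as xr

# a 2x2 map of region ids
region_da = xr.DataArray([[0, 1], [1, 0]], dims=["lat", "lon"])

# DataArray of ids indexed by region name, mirroring region_lookup
regions = xr.DataArray([0, 1], coords=[["region_1", "region_2"]], dims=["region"])

# the comparison broadcasts to a (region, lat, lon) stack of boolean masks
mask = regions == region_da
print(mask.dims, mask.shape)
```

Selecting `mask.sel(region="region_1")` then gives the 2-D boolean mask for that region, which is exactly what `where` consumes.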