xarray.apply_ufunc() with GroupBy: unexpected number of dimensions
I'm using xarray.apply_ufunc() to apply a function to an xarray.DataArray. It works for some NetCDFs but fails for others that appear comparable in dimensions, coordinates, etc. There must be some difference between the NetCDFs the code works on and those it fails on, and I'm hoping someone can spot the problem from the code and the file metadata listed below.
The code I run to perform the computation is this:
# open the precipitation NetCDF as an xarray DataSet object
dataset = xr.open_dataset(kwrgs['netcdf_precip'])
# get the precipitation array, over which we'll compute the SPI
da_precip = dataset[kwrgs['var_name_precip']]
# stack the lat and lon dimensions into a new dimension named point, so at each lat/lon
# we'll have a time series for the geospatial point, and group by these points
da_precip_groupby = da_precip.stack(point=('lat', 'lon')).groupby('point')
# apply the SPI function to the data array
da_spi = xr.apply_ufunc(indices.spi,
                        da_precip_groupby)
# unstack the array back into original dimensions
da_spi = da_spi.unstack('point')
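To make the pattern above concrete, here is a minimal, self-contained sketch of the same stack -> groupby -> apply_ufunc -> unstack pipeline. The `standardize` function is a hypothetical stand-in for `indices.spi` (which is external to this snippet); it just returns a per-point time series of the same length, as SPI does.

```python
import numpy as np
import xarray as xr

# Small synthetic precipitation-like array with ascending coordinates
da = xr.DataArray(
    np.random.rand(4, 3, 24),
    dims=("lat", "lon", "time"),
    coords={
        "lat": np.linspace(25.0, 49.0, 4),
        "lon": np.linspace(-124.0, -67.0, 3),
        "time": np.arange(24),
    },
)

def standardize(values):
    # Placeholder for indices.spi: takes a 1-D series for one grid point
    # and returns a series of the same shape.
    return (values - values.mean()) / values.std()

# Stack lat/lon into a single "point" dimension, group by point,
# apply the function per point, then restore the original dimensions.
grouped = da.stack(point=("lat", "lon")).groupby("point")
result = xr.apply_ufunc(standardize, grouped).unstack("point")
print(result.dims)
```

Note that `apply_ufunc` accepts a GroupBy object directly and applies the function to each group's underlying array, which is why `standardize` must return data with the same dimensions as each group.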
The NetCDF that works looks like this:
>>> import xarray as xr
>>> ds_good = xr.open_dataset("good.nc")
>>> ds_good
<xarray.Dataset>
Dimensions: (lat: 38, lon: 87, time: 1466)
Coordinates:
* lat (lat) float32 24.5625 25.229166 25.895834 ... 48.5625 49.229168
* lon (lon) float32 -124.6875 -124.020836 ... -68.020836 -67.354164
* time (time) datetime64[ns] 1895-01-01 1895-02-01 ... 2017-02-01
Data variables:
prcp (lat, lon, time) float32 ...
Attributes:
Conventions: CF-1.6, ACDD-1.3
ncei_template_version: NCEI_NetCDF_Grid_Template_v2.0
title: nClimGrid
naming_authority: gov.noaa.ncei
standard_name_vocabulary: Standard Name Table v35
institution: National Centers for Environmental Information...
geospatial_lat_min: 24.5625
geospatial_lat_max: 49.354168
geospatial_lon_min: -124.6875
geospatial_lon_max: -67.020836
geospatial_lat_units: degrees_north
geospatial_lon_units: degrees_east
NCO: 4.7.1
nco_openmp_thread_number: 1
>>> ds_good.prcp
<xarray.DataArray 'prcp' (lat: 38, lon: 87, time: 1466)>
[4846596 values with dtype=float32]
Coordinates:
* lat (lat) float32 24.5625 25.229166 25.895834 ... 48.5625 49.229168
* lon (lon) float32 -124.6875 -124.020836 ... -68.020836 -67.354164
* time (time) datetime64[ns] 1895-01-01 1895-02-01 ... 2017-02-01
Attributes:
valid_min: 0.0
units: millimeter
valid_max: 2000.0
standard_name: precipitation_amount
long_name: Precipitation, monthly total
The NetCDF that fails looks like this:
>>> ds_bad = xr.open_dataset("bad.nc")
>>> ds_bad
<xarray.Dataset>
Dimensions: (lat: 38, lon: 87, time: 1483)
Coordinates:
* lat (lat) float32 49.3542 48.687534 48.020866 ... 25.3542 24.687532
* lon (lon) float32 -124.6875 -124.020836 ... -68.020836 -67.354164
* time (time) datetime64[ns] 1895-01-01 1895-02-01 ... 2018-07-01
Data variables:
prcp (lat, lon, time) float32 ...
Attributes:
date_created: 2018-02-15 10:29:25.485927
date_modified: 2018-02-15 10:29:25.486042
Conventions: CF-1.6, ACDD-1.3
ncei_template_version: NCEI_NetCDF_Grid_Template_v2.0
title: nClimGrid
naming_authority: gov.noaa.ncei
standard_name_vocabulary: Standard Name Table v35
institution: National Centers for Environmental Information...
geospatial_lat_min: 24.562532
geospatial_lat_max: 49.3542
geospatial_lon_min: -124.6875
geospatial_lon_max: -67.020836
geospatial_lat_units: degrees_north
geospatial_lon_units: degrees_east
>>> ds_bad.prcp
<xarray.DataArray 'prcp' (lat: 38, lon: 87, time: 1483)>
[4902798 values with dtype=float32]
Coordinates:
* lat (lat) float32 49.3542 48.687534 48.020866 ... 25.3542 24.687532
* lon (lon) float32 -124.6875 -124.020836 ... -68.020836 -67.354164
* time (time) datetime64[ns] 1895-01-01 1895-02-01 ... 2018-07-01
Attributes:
valid_min: 0.0
long_name: Precipitation, monthly total
standard_name: precipitation_amount
units: millimeter
valid_max: 2000.0
When I run the code against the first file above, it works fine. With the second file, I get an error like this:
multiprocessing.pool.RemoteTraceback:
"""
Traceback (most recent call last):
File "/home/paperspace/anaconda3/envs/climate/lib/python3.6/multiprocessing/pool.py", line 119, in worker
result = (True, func(*args, **kwds))
File "/home/paperspace/anaconda3/envs/climate/lib/python3.6/multiprocessing/pool.py", line 44, in mapstar
return list(map(*args))
File "/home/paperspace/git/climate_indices/scripts/process_grid_ufunc.py", line 278, in compute_write_spi
kwargs=args_dict)
File "/home/paperspace/anaconda3/envs/climate/lib/python3.6/site-packages/xarray/core/computation.py", line 974, in apply_ufunc
return apply_groupby_ufunc(this_apply, *args)
File "/home/paperspace/anaconda3/envs/climate/lib/python3.6/site-packages/xarray/core/computation.py", line 432, in apply_groupby_ufunc
applied_example, applied = peek_at(applied)
File "/home/paperspace/anaconda3/envs/climate/lib/python3.6/site-packages/xarray/core/utils.py", line 133, in peek_at
peek = next(gen)
File "/home/paperspace/anaconda3/envs/climate/lib/python3.6/site-packages/xarray/core/computation.py", line 431, in <genexpr>
applied = (func(*zipped_args) for zipped_args in zip(*iterators))
File "/home/paperspace/anaconda3/envs/climate/lib/python3.6/site-packages/xarray/core/computation.py", line 987, in apply_ufunc
exclude_dims=exclude_dims)
File "/home/paperspace/anaconda3/envs/climate/lib/python3.6/site-packages/xarray/core/computation.py", line 211, in apply_dataarray_ufunc
result_var = func(*data_vars)
File "/home/paperspace/anaconda3/envs/climate/lib/python3.6/site-packages/xarray/core/computation.py", line 579, in apply_variable_ufunc
.format(data.ndim, len(dims), dims))
ValueError: applied function returned data with unexpected number of dimensions: 1 vs 2, for dimensions ('time', 'point')
Can anyone comment on what the problem might be?
It turns out that the NetCDF files causing the problem had their latitude coordinate values in descending order. xarray.apply_ufunc() appears to require coordinate values in ascending order, at least to avoid this particular issue. Reversing the coordinate values of the problematic dimension with NCO's ncpdq command before using the NetCDF file as input to xarray fixed this easily.
Thanks for the replies.
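If you would rather stay inside xarray instead of preprocessing with NCO's ncpdq, the same fix can be done with `Dataset.sortby`. A small sketch, using a made-up dataset whose `lat` coordinate is descending like the failing file:

```python
import numpy as np
import xarray as xr

# Synthetic dataset with a descending lat coordinate, mimicking bad.nc
ds = xr.Dataset(
    {"prcp": (("lat", "lon", "time"), np.random.rand(3, 2, 5))},
    coords={
        "lat": [49.0, 37.0, 25.0],   # descending, like the failing file
        "lon": [-124.0, -67.0],
        "time": np.arange(5),
    },
)

# Sort lat into ascending order; data values are reordered to match.
ds_sorted = ds.sortby("lat")
print(ds_sorted["lat"].values)
```

After sorting, the stack/groupby/apply_ufunc pipeline can proceed as before.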
Sometimes sorting the dimensions into ascending order is enough to resolve problems with xr.apply_ufunc. However, sometimes that strategy is not sufficient.
An alternative solution is to stack the coordinates that are broadcast by the external user function into a new dimension (e.g., stacking the 'Longitude' and 'Latitude' dimensions into a single dimension named 'Grid_Point'). After stacking, you can group by this new 'Grid_Point' dimension and apply xr.apply_ufunc.
Here is an example showing how to derive, per pixel, the statistical moments ('mean' and 'standard deviation') of a Gaussian distribution fitted to a NetCDF temperature dataset.
import xarray as xr
# http://xarray.pydata.org/en/stable/dask.html
from scipy import stats
from dask.diagnostics import ProgressBar
import numpy as np
import warnings


def get_params_from_distribution(data, distribution='exponweib'):
    distribution = getattr(stats, distribution)
    if np.all(np.isnan(data)):
        # All-NaN series: work out how many parameters the fit would
        # return (by fitting a dummy sample) and return that many NaNs.
        with warnings.catch_warnings():
            warnings.filterwarnings(action="ignore")
            try:
                temp_data = distribution.rvs(1, size=10)
            except TypeError:
                try:
                    temp_data = distribution.rvs(1, 1, size=10)
                except TypeError:
                    temp_data = distribution.rvs(1, 1, 1, size=10)
            n_params = len(distribution.fit(temp_data))
        return data[:n_params]
    else:
        return list(distribution.fit(data))


def get_params_vectorized_from_stacked(stacked_data, distribution='exponweib',
                                       dask='allowed',
                                       input_core_dims='time',
                                       output_core_dims='stat_moments',
                                       output_dtypes=[float]):
    kwargs = {'distribution': distribution}
    with ProgressBar():
        da_params = xr.apply_ufunc(get_params_from_distribution,
                                   stacked_data,
                                   exclude_dims={input_core_dims},
                                   kwargs=kwargs,
                                   input_core_dims=[[input_core_dims]],
                                   output_core_dims=[[output_core_dims]],
                                   dask=dask,
                                   output_dtypes=output_dtypes).compute()
    return da_params


def stack_ds(ds, dims=['lon', 'lat'], stacked_dim_name='point'):
    return ds.stack({stacked_dim_name: dims})


def main_pdf_u_function_getter(ds,
                               dims_to_stack=['lon', 'lat'],
                               stacked_dim_name='point',
                               distribution_name='exponweib',
                               dask='allowed',
                               input_core_dims='time',
                               output_core_dims='stat_moments',
                               output_dtypes=[float]):
    ds_stacked = stack_ds(ds, dims_to_stack, stacked_dim_name)  # observation 1
    ds_groupby = ds_stacked.groupby(stacked_dim_name)           # observation 1
    results = get_params_vectorized_from_stacked(ds_groupby,
                                                 distribution=distribution_name,
                                                 dask=dask,
                                                 output_core_dims=output_core_dims,
                                                 input_core_dims=input_core_dims,
                                                 output_dtypes=output_dtypes)
    return results.unstack(stacked_dim_name)


if __name__ == '__main__':
    ds = xr.tutorial.open_dataset('air_temperature').sortby(['lat', 'lon', 'time'])
    R = main_pdf_u_function_getter(ds,
                                   dask='parallelized',
                                   dims_to_stack=['lon', 'lat'],
                                   stacked_dim_name='point',
                                   distribution_name='norm')
    print(R)

    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(1, 2)
    ax = ax.ravel()
    for moment in range(R.dims['stat_moments']):
        R['air'].isel({'stat_moments': moment}).plot(ax=ax[moment], cmap='viridis')
Note the lines in the code above marked with the comment "observation 1". These are the key lines that make the whole algorithm work: they stack the broadcast dimensions before the ufunc operation.
Although the solution above works, I still don't know why the stacking has to be done before xr.apply_ufunc. That remains an open question.
Regards,