Using xr.apply_ufunc to filter a 3D xarray array

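Today I ran into a problem applying xr.apply_ufunc correctly to a 3D xarray DataArray.

I want to filter the array along the 'time' dimension. The final product should therefore have the same dimensions as the input array and should contain only the filtered data rather than the monthly data. To do this, I wrote the following two functions: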

import xarray as xr
from scipy import signal
import numpy as np

def butter_filt(x,filt_year,fs,order_butter):

    #filt_year = 1 #1 year
    #fs = 12 #monthly data
    #fn = fs/2; # Nyquist Frequency
    fc = (1/filt_year)/2 # cut off frequency 1sample/ 1year = (1/1)/2 equals 1 year filter (two half cycles/sample)
    #fc = (1/2)/2 # cut off frequency 1sample/ 2year = (1/1)/2 equals 2 year filter (two half cycles/sample)
    #fc = (1/4)/2 # cut off frequency 1sample/ 4year = (1/1)/2 equals 4 year filter (two half cycles/sample)
    b, a = signal.butter(order_butter, fc, 'low', fs=fs, output='ba')

    # Count the non-NaN values
    co = np.count_nonzero(~np.isnan(x))
    if co < 4: # If fewer than 4 observations return -9999
        # np.empty would return uninitialized memory, so fill explicitly
        return np.full(x.shape, -9999.0)
    else:
        return signal.filtfilt(b, a, x)

def filtfilt_butter(x,filt_year,fs,order_butter,dim='time'):
    # x ...... xr data array
    # dim .... dimension along which to apply function
    filt = xr.apply_ufunc(butter_filt, x, filt_year, fs, order_butter,
                          input_core_dims=[[dim], [], [], []],
                          dask='parallelized')

    return filt

x_uv = filtfilt_butter(ds.x,
                   filt_year=1,
                   fs=12,
                   order_butter=2,
                   dim='time')

I tested the butter_filt function and it works fine on its own, so something must be wrong in filtfilt_butter. Trying to compute x_uv raises the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<timed exec> in <module>

<ipython-input-86-c470080a0817> in filtfilt_butter(x, filt_year, fs, order_butter, dim)
     20     # x ...... xr data array
     21     # dims .... dimension aong which to apply function
---> 22     filt= xr.apply_ufunc(butter_filt, x,filt_year,fs,order_butter,
     23                          input_core_dims=[[dim], [], [], []],
     24                            dask='parallelized')

~/miniconda3/envs/pyt3_11102018/lib/python3.8/site-packages/xarray/core/computation.py in 
apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, 
dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, *args)
   1056         )
   1057     elif any(isinstance(a, DataArray) for a in args):
-> 1058         return apply_dataarray_vfunc(
   1059             variables_vfunc,
   1060             *args,

~/miniconda3/envs/pyt3_11102018/lib/python3.8/site-packages/xarray/core/computation.py in 
apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args)
    231 
    232     data_vars = [getattr(a, "variable", a) for a in args]
--> 233     result_var = func(*data_vars)
    234 
    235     if signature.num_outputs > 1:

~/miniconda3/envs/pyt3_11102018/lib/python3.8/site-packages/xarray/core/computation.py in 
 apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, output_sizes, keep_attrs, 
meta, *args)
    621         data = as_compatible_data(data)
    622         if data.ndim != len(dims):
--> 623             raise ValueError(
    624                 "applied function returned data with unexpected "
    625                 "number of dimensions: {} vs {}, for dimensions {}".format(

ValueError: applied function returned data with unexpected number of dimensions: 3 vs 2, for 
dimensions ('deptht', 'km')

How can I solve this?

I found a way to solve this problem by defining both the input and the output core dimensions:

def butter_filt(x,filt_year,fs,order_butter):

    #filt_year = 1 #1 year
    #fs = 12 #monthly data
    #fn = fs/2; # Nyquist Frequency
    fc = (1/filt_year)/2 # cut off frequency 1sample/ 1year = (1/1)/2 equals 1 year filter (two half cycles/sample)
    #fc = (1/2)/2 # cut off frequency 1sample/ 2year = (1/1)/2 equals 2 year filter (two half cycles/sample)
    #fc = (1/4)/2 # cut off frequency 1sample/ 4year = (1/1)/2 equals 4 year filter (two half cycles/sample)
    b, a = signal.butter(order_butter, fc, 'low', fs=fs, output='ba')

    return signal.filtfilt(b, a, x)


def filtfilt_butter(x,filt_year,fs,order_butter,dim='time'):
    # x ...... xr data array
    # dim .... dimension along which to apply function
    filt = xr.apply_ufunc(
        butter_filt,  # first the function
        x,  # now arguments in the order expected by 'butter_filt'
        filt_year,  # as above
        fs,  # as above
        order_butter,  # as above
        input_core_dims=[["deptht","km","time"], [], [], []],  # list with one entry per arg
        output_core_dims=[["deptht","km","time"]],  # returned data has 3 dimensions
        exclude_dims={"time"},  # dimensions allowed to change size. Must be a set!
        vectorize=True,  # loop over non-core dims
    )

    return filt

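apply_ufunc expects the output shape to be the initial shape minus input_core_dims plus output_core_dims. In your case, passing time as an input core dim is correct: it guarantees that time is moved to the last axis, so filtering with axis=-1 works as intended.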
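The shape rule above can be checked on a small array. The following is a minimal sketch, not code from the question: the array `da`, its sizes, and the helper `record_and_reduce` are made up for illustration. It shows that the core dimension arrives as the last axis, that a function which drops it needs no output_core_dims, and that a function which keeps it without output_core_dims triggers the same kind of ValueError as in the traceback.

```python
import numpy as np
import xarray as xr

# Hypothetical stand-in data with dims ordered (time, deptht, km).
da = xr.DataArray(np.random.rand(24, 3, 4), dims=("time", "deptht", "km"))

seen_shapes = []

def record_and_reduce(x):
    # x is a plain NumPy array; the core dim ("time") has been moved to the last axis.
    seen_shapes.append(x.shape)
    # Reducing over the last axis removes the core dim from the result.
    return x.mean(axis=-1)

reduced = xr.apply_ufunc(record_and_reduce, da, input_core_dims=[["time"]])

print(seen_shapes[0])  # (3, 4, 24)
print(reduced.dims)    # ('deptht', 'km')

# Returning an array that still has the time axis, without declaring
# output_core_dims, does not match the expected 2D result and raises.
error_message = None
try:
    xr.apply_ufunc(lambda x: x, da, input_core_dims=[["time"]])
except ValueError as err:
    error_message = str(err)

print(error_message)
```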

So you need to use output_core_dims to tell xarray that the function returns a 3D array. You can use time there as well.
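As a rough sketch of that suggestion, assuming synthetic data in place of ds.x and a hypothetical helper name `butter_filt_last_axis`, declaring time as both the input and the output core dimension could look like this. Output core dimensions are placed last in the result, so the sketch transposes back to the original dimension order afterwards.

```python
import numpy as np
import xarray as xr
from scipy import signal

def butter_filt_last_axis(x, filt_year, fs, order_butter):
    # Same filter design as in the question; filtfilt runs along the last axis,
    # which is where apply_ufunc places the core dimension.
    fc = (1 / filt_year) / 2
    b, a = signal.butter(order_butter, fc, "low", fs=fs, output="ba")
    return signal.filtfilt(b, a, x, axis=-1)

# Synthetic stand-in for ds.x (an assumption): 120 monthly steps on a 3 x 4 grid.
rng = np.random.default_rng(0)
da = xr.DataArray(rng.standard_normal((120, 3, 4)), dims=("time", "deptht", "km"))

filtered = xr.apply_ufunc(
    butter_filt_last_axis,
    da,
    1,   # filt_year
    12,  # fs (monthly data)
    2,   # order_butter
    input_core_dims=[["time"], [], [], []],  # one entry per argument
    output_core_dims=[["time"]],             # the result keeps the time axis
)

# The result comes back as (deptht, km, time); restore the input order.
filtered = filtered.transpose(*da.dims)

print(filtered.dims)   # ('time', 'deptht', 'km')
print(filtered.shape)  # (120, 3, 4)
```

Because filtfilt already works on N-dimensional input along a chosen axis, this variant calls the function once on the whole array and does not need vectorize=True.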

For a more detailed explanation of the apply_ufunc arguments, see the xarray documentation for apply_ufunc.