Convert a multiway pandas.crosstab to an xarray

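I want to create a multiway contingency table from my pandas DataFrame and store it in an xarray. It seemed to me that pandas.crosstab followed by DataFrame.to_xarray() should be straightforward enough, but with pandas v1.1.5 I get "TypeError: Cannot interpret 'interval[int64]' as a data type". (v1.0.1 gives "ValueError: all arrays must be same length".)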

In [1]: import numpy as np
   ...: import pandas as pd
   ...: pd.__version__
Out[1]: '1.1.5'

In [2]: import xarray as xr
   ...: xr.__version__
Out[2]: '0.17.0'

In [3]: n = 100
   ...: np.random.seed(42)
   ...: x = pd.cut(np.random.uniform(low=0, high=3, size=n), range(5))
   ...: x
Out[3]: 
[(1, 2], (2, 3], (2, 3], (1, 2], (0, 1], ..., (1, 2], (1, 2], (1, 2], (0, 1], (0, 1]]
Length: 100
Categories (4, interval[int64]): [(0, 1] < (1, 2] < (2, 3] < (3, 4]]

In [4]: x.value_counts().sort_index()
Out[4]: 
(0, 1]    41
(1, 2]    28
(2, 3]    31
(3, 4]     0
dtype: int64

Note that I need the table to include empty categories such as (3, 4].

In [6]: idx=pd.date_range('2001-01-01', periods=n, freq='8H')
   ...: df = pd.DataFrame({'x': x}, index=idx)
   ...: df['xlag'] = df.x.shift(1, 'D')
   ...: df['h'] = df.index.hour
   ...: xtab = pd.crosstab([df.h, df.xlag], df.x, dropna=False, normalize='index')
   ...: xtab
Out[6]: 
x            (0, 1]    (1, 2]    (2, 3]  (3, 4]
h  xlag                                        
0  (0, 1]  0.000000  0.700000  0.300000     0.0
   (1, 2]  0.470588  0.411765  0.117647     0.0
   (2, 3]  0.500000  0.333333  0.166667     0.0
   (3, 4]  0.000000  0.000000  0.000000     0.0
8  (0, 1]  0.588235  0.000000  0.411765     0.0
   (1, 2]  1.000000  0.000000  0.000000     0.0
   (2, 3]  0.428571  0.142857  0.428571     0.0
   (3, 4]  0.000000  0.000000  0.000000     0.0
16 (0, 1]  0.333333  0.250000  0.416667     0.0
   (1, 2]  0.444444  0.222222  0.333333     0.0
   (2, 3]  0.454545  0.363636  0.181818     0.0
   (3, 4]  0.000000  0.000000  0.000000     0.0

That works, but my real application has more categories and more dimensions, so it looks like a clear use case for xarray. However, I get an error:

In [8]: xtab.to_xarray()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-aaedf730bb97> in <module>
----> 1 xtab.to_xarray()

/opt/scitools/environments/default/2021_03_18-1/lib/python3.6/site-packages/pandas/core/generic.py in to_xarray(self)
   2818             return xarray.DataArray.from_series(self)
   2819         else:
-> 2820             return xarray.Dataset.from_dataframe(self)
   2821 
   2822     @Substitution(returns=fmt.return_docstring)

/opt/scitools/environments/default/2021_03_18-1/lib/python3.6/site-packages/xarray/core/dataset.py in from_dataframe(cls, dataframe, sparse)
   5131             obj._set_sparse_data_from_dataframe(idx, arrays, dims)
   5132         else:
-> 5133             obj._set_numpy_data_from_dataframe(idx, arrays, dims)
   5134         return obj
   5135 

/opt/scitools/environments/default/2021_03_18-1/lib/python3.6/site-packages/xarray/core/dataset.py in _set_numpy_data_from_dataframe(self, idx, arrays, dims)
   5062                 data = np.zeros(shape, values.dtype)
   5063             data[indexer] = values
-> 5064             self[name] = (dims, data)
   5065 
   5066     @classmethod

/opt/scitools/environments/default/2021_03_18-1/lib/python3.6/site-packages/xarray/core/dataset.py in __setitem__(self, key, value)
   1427             )
   1428 
-> 1429         self.update({key: value})
   1430 
   1431     def __delitem__(self, key: Hashable) -> None:

/opt/scitools/environments/default/2021_03_18-1/lib/python3.6/site-packages/xarray/core/dataset.py in update(self, other)
   3897         Dataset.assign
   3898         """
-> 3899         merge_result = dataset_update_method(self, other)
   3900         return self._replace(inplace=True, **merge_result._asdict())
   3901 

/opt/scitools/environments/default/2021_03_18-1/lib/python3.6/site-packages/xarray/core/merge.py in dataset_update_method(dataset, other)
    958         priority_arg=1,
    959         indexes=indexes,
--> 960         combine_attrs="override",
    961     )

/opt/scitools/environments/default/2021_03_18-1/lib/python3.6/site-packages/xarray/core/merge.py in merge_core(objects, compat, join, combine_attrs, priority_arg, explicit_coords, indexes, fill_value)
    609     coerced = coerce_pandas_values(objects)
    610     aligned = deep_align(
--> 611         coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value
    612     )
    613     collected = collect_variables_and_indexes(aligned)

/opt/scitools/environments/default/2021_03_18-1/lib/python3.6/site-packages/xarray/core/alignment.py in deep_align(objects, join, copy, indexes, exclude, raise_on_invalid, fill_value)
    428         indexes=indexes,
    429         exclude=exclude,
--> 430         fill_value=fill_value,
    431     )
    432 

/opt/scitools/environments/default/2021_03_18-1/lib/python3.6/site-packages/xarray/core/alignment.py in align(join, copy, indexes, exclude, fill_value, *objects)
    352         if not valid_indexers:
    353             # fast path for no reindexing necessary
--> 354             new_obj = obj.copy(deep=copy)
    355         else:
    356             new_obj = obj.reindex(

/opt/scitools/environments/default/2021_03_18-1/lib/python3.6/site-packages/xarray/core/dataset.py in copy(self, deep, data)
   1218         """
   1219         if data is None:
-> 1220             variables = {k: v.copy(deep=deep) for k, v in self._variables.items()}
   1221         elif not utils.is_dict_like(data):
   1222             raise ValueError("Data must be dict-like")

/opt/scitools/environments/default/2021_03_18-1/lib/python3.6/site-packages/xarray/core/dataset.py in <dictcomp>(.0)
   1218         """
   1219         if data is None:
-> 1220             variables = {k: v.copy(deep=deep) for k, v in self._variables.items()}
   1221         elif not utils.is_dict_like(data):
   1222             raise ValueError("Data must be dict-like")

/opt/scitools/environments/default/2021_03_18-1/lib/python3.6/site-packages/xarray/core/variable.py in copy(self, deep, data)
   2632         """
   2633         if data is None:
-> 2634             data = self._data.copy(deep=deep)
   2635         else:
   2636             data = as_compatible_data(data)

/opt/scitools/environments/default/2021_03_18-1/lib/python3.6/site-packages/xarray/core/indexing.py in copy(self, deep)
   1484         # 8000341
   1485         array = self.array.copy(deep=True) if deep else self.array
-> 1486         return PandasIndexAdapter(array, self._dtype)

/opt/scitools/environments/default/2021_03_18-1/lib/python3.6/site-packages/xarray/core/indexing.py in __init__(self, array, dtype)
   1407                 dtype_ = array.dtype
   1408         else:
-> 1409             dtype_ = np.dtype(dtype)
   1410         self._dtype = dtype_
   1411 

TypeError: Cannot interpret 'interval[int64]' as a data type


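I can avoid the error by converting x (and xlag) from pandas.Categorical to a different data type before calling pandas.crosstab, but then I lose all the empty categories, which I need to keep in my real application.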
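A minimal sketch of that trade-off, not the exact frame from the question: here `h` is a hypothetical stand-in for the hour column, built directly with NumPy instead of from a DatetimeIndex, and only a two-way table is computed. It should show that plain strings lose the unobserved bin while the categorical keeps it.

```python
import numpy as np
import pandas as pd

np.random.seed(42)
n = 100
x = pd.cut(np.random.uniform(low=0, high=3, size=n), range(5))

# Assumption for illustration: a simple repeating 0/8/16 "hour" column.
h = pd.Series((np.arange(n) % 3) * 8, name="h")

# Plain strings: crosstab only sees labels that actually occur in the data.
as_str = pd.crosstab(h, pd.Series(x.astype(str), name="x"))

# Categorical with dropna=False: the unobserved (3, 4] bin is kept as a column.
as_cat = pd.crosstab(h, pd.Series(x, name="x"), dropna=False)

print(as_str.shape)  # one column per observed label
print(as_cat.shape)  # one column per category, observed or not
```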

The problem here is not the use of a CategoricalIndex. It is that the category labels (x.categories) are an IntervalIndex, which xarray doesn't like.

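To work around this, you can simply replace the categories in your x variable with their string representations, which forces x.categories to have "object" dtype instead of "interval[int64]" dtype: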

x = (
    pd.cut(np.random.uniform(low=0, high=3, size=n), range(5))
    .rename_categories(str)
)

Then compute your crosstab as you already did, and it should work!
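If you want to check what the rename does before rebuilding the crosstab, something like the following sketch should do. It only inspects the categorical itself; the variable names `x_str` and `counts` are mine, not from the question.

```python
import numpy as np
import pandas as pd

np.random.seed(42)
x = pd.cut(np.random.uniform(low=0, high=3, size=100), range(5))

# Before: the category labels are pandas Interval objects.
print(type(x.categories).__name__, x.categories.dtype)

# After: the same categories, in the same order, as plain strings.
x_str = x.rename_categories(str)
print(type(x_str.categories).__name__, x_str.categories.dtype)

# The empty category is still there; only the labels changed.
counts = x_str.value_counts()
print(counts)
```

Because only the labels are renamed, the codes and the category order are untouched, so empty bins such as '(3, 4]' should survive.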


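To get your data onto the coordinates you want (I think), all you need to do is stack everything into a MultiIndex-rows shape (instead of the crosstab's MultiIndex rows / Index columns shape).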

xtab = (
    pd.crosstab([df.h, df.xlag], df.x, dropna=False, normalize="index")
    .stack()
    .reorder_levels(["x", "h", "xlag"])
    .sort_index()
)
xtab.to_xarray()

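If you want shorter code and are happy to drop the explicit reordering of index levels, you can also use unstack instead of stack, which gives you the same ordering right away: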

xtab = (
    pd.crosstab([df.h, df.xlag], df.x, dropna=False, normalize="index")
    .unstack([0, 1])
)
xtab.to_xarray()

Whether you use the stack() or the unstack([0, 1]) approach, you get the following output:

<xarray.DataArray (x: 4, h: 3, xlag: 4)>
array([[[0.        , 0.47058824, 0.5       , 0.        ],
        [0.58823529, 1.        , 0.42857143, 0.        ],
        [0.33333333, 0.44444444, 0.45454545, 0.        ]],

       [[0.7       , 0.41176471, 0.33333333, 0.        ],
        [0.        , 0.        , 0.14285714, 0.        ],
        [0.25      , 0.22222222, 0.36363636, 0.        ]],

       [[0.3       , 0.11764706, 0.16666667, 0.        ],
        [0.41176471, 0.        , 0.42857143, 0.        ],
        [0.41666667, 0.33333333, 0.18181818, 0.        ]],

       [[0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ]]])
Coordinates:
  * x        (x) object '(0, 1]' '(1, 2]' '(2, 3]' '(3, 4]'
  * h        (h) int64 0 8 16
  * xlag     (xlag) object '(0, 1]' '(1, 2]' '(2, 3]' '(3, 4]'
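As a rough sanity check that the two reshapes agree, here is a pandas-only sketch. It simplifies the question's setup under two assumptions of mine: `h` is built with NumPy rather than from a DatetimeIndex, and the one-day lag is approximated by `shift(3)` (three 8-hour steps), with the resulting NaN rows dropped.

```python
import numpy as np
import pandas as pd

np.random.seed(42)
n = 100
x = pd.cut(np.random.uniform(low=0, high=3, size=n), range(5)).rename_categories(str)

# Simplified stand-in for the question's frame (assumptions, see above).
df = pd.DataFrame({"x": x, "h": (np.arange(n) % 3) * 8})
df["xlag"] = df["x"].shift(3)  # value three rows (one "day") earlier
df = df.dropna()               # drop the first rows, which have no lag value

xtab = pd.crosstab([df.h, df.xlag], df.x, dropna=False, normalize="index")

# Route 1: stack the columns into the row index, then reorder and sort.
stacked = xtab.stack().reorder_levels(["x", "h", "xlag"]).sort_index()

# Route 2: unstack both row levels into the columns, which yields a Series.
unstacked = xtab.unstack([0, 1])

print(stacked.index.names, unstacked.index.names)
print(np.allclose(stacked.to_numpy(), unstacked.to_numpy()))
```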

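@Cameron-Riddell's answer was the key to solving my problem, but there was some additional reshaping left to sort out. Applying rename_categories(str) to my x variable as he suggested and then proceeding as in my question allows the last line to work: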

In [8]: xtab = pd.crosstab([df.h, df.xlag], df.x, dropna=False, normalize='index')
   ...: xtab.to_xarray()
Out[8]: 
<xarray.Dataset>
Dimensions:  (h: 3, xlag: 4)
Coordinates:
  * h        (h) int64 0 8 16
  * xlag     (xlag) object '(0, 1]' '(1, 2]' '(2, 3]' '(3, 4]'
Data variables:
    (0, 1]   (h, xlag) float64 0.0 0.4706 0.5 0.0 ... 0.3333 0.4444 0.4545 0.0
    (1, 2]   (h, xlag) float64 0.7 0.4118 0.3333 0.0 ... 0.25 0.2222 0.3636 0.0
    (2, 3]   (h, xlag) float64 0.3 0.1176 0.1667 0.0 ... 0.3333 0.1818 0.0
    (3, 4]   (h, xlag) float64 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0

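But I want a 3-dimensional array with one variable, not a 2-dimensional dataset with four variables. To convert it, I need to apply .to_array(dim='x'). That leaves my dimensions in the order x, h, xlag, and I clearly don't want h in the middle, so I also need to transpose them: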

In [9]: xtab.to_xarray().to_array(dim='x').transpose('h', 'xlag', 'x')
Out[9]: 
<xarray.DataArray (h: 3, xlag: 4, x: 4)>
array([[[0.        , 0.7       , 0.3       , 0.        ],
        [0.47058824, 0.41176471, 0.11764706, 0.        ],
        [0.5       , 0.33333333, 0.16666667, 0.        ],
        [0.        , 0.        , 0.        , 0.        ]],

       [[0.58823529, 0.        , 0.41176471, 0.        ],
        [1.        , 0.        , 0.        , 0.        ],
        [0.42857143, 0.14285714, 0.42857143, 0.        ],
        [0.        , 0.        , 0.        , 0.        ]],

       [[0.33333333, 0.25      , 0.41666667, 0.        ],
        [0.44444444, 0.22222222, 0.33333333, 0.        ],
        [0.45454545, 0.36363636, 0.18181818, 0.        ],
        [0.        , 0.        , 0.        , 0.        ]]])
Coordinates:
  * h        (h) int64 0 8 16
  * xlag     (xlag) object '(0, 1]' '(1, 2]' '(2, 3]' '(3, 4]'
  * x        (x) <U6 '(0, 1]' '(1, 2]' '(2, 3]' '(3, 4]'

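That's what I had in mind! It displays much like pd.crosstab, but it is a 3-dimensional xarray rather than a pandas DataFrame with a MultiIndex. It will be much easier to handle in the later stages of my program (the crosstab is just an intermediate step, not a result in itself).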
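For what it's worth, a plain stack() without any reordering appears to land on the same (h, xlag, x) layout in one step. The sketch below compares it with the to_array/transpose route. It needs xarray installed and uses a simplified frame that is my assumption, not the original data: `h` comes from NumPy, the lag is `shift(3)`, and NaN rows are dropped.

```python
import numpy as np
import pandas as pd

np.random.seed(42)
n = 100
x = pd.cut(np.random.uniform(low=0, high=3, size=n), range(5)).rename_categories(str)

# Simplified stand-in for the question's frame (assumptions, see above).
df = pd.DataFrame({"x": x, "h": (np.arange(n) % 3) * 8})
df["xlag"] = df["x"].shift(3)
df = df.dropna()

xtab = pd.crosstab([df.h, df.xlag], df.x, dropna=False, normalize="index")

# Route A: Dataset with one variable per column, merged along a new "x" dim.
via_dataset = xtab.to_xarray().to_array(dim="x").transpose("h", "xlag", "x")

# Route B: stack the columns into the row index, then convert the Series.
via_stack = xtab.stack().to_xarray()

print(via_dataset.dims, via_stack.dims)
print(np.allclose(via_dataset.values, via_stack.values))
```

Route B skips the intermediate Dataset, at the cost of relying on the row and column names that crosstab assigns.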

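I must say this ended up more complicated than I expected... I found a 2017 question by @kilojoules, to which the reply began "There does seem to be a transition to xarray for doing work on multi-dimensional arrays." It seems a pity to me that there isn't a version of pd.crosstab that returns an xarray. Or am I asking for more pandas-xarray integration than is possible?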