如何获取xarray中最大值的坐标?

How to get the coordinates of the maximum in xarray?

简单问题:我不仅想要最大值,还想要它在 xarray DataArray 中的坐标。怎么做?

我当然可以自己写一个简单的reduce函数,但我想知道xarray有没有内置的东西?

idxmax() 方法在 xarray 中是 very welcome,但还没有人开始实施它。

目前,您可以结合argmaxisel找到最大值的坐标:

>>> array = xarray.DataArray(
...    [[1, 2, 3], [3, 2, 1]],
...    dims=['x', 'y'],
...    coords={'x': [1, 2], 'y': ['a', 'b', 'c']})

>>> array
<xarray.DataArray (x: 2, y: 3)>
array([[1, 2, 3],
       [3, 2, 1]])
Coordinates:
  * x        (x) int64 1 2
  * y        (y) <U1 'a' 'b' 'c'

>>> array.isel(y=array.argmax('y'))
<xarray.DataArray (x: 2)>
array([3, 3])
Coordinates:
  * x        (x) int64 1 2
    y        (x) <U1 'c' 'a'

这可能是 .max() 在任何情况下都应该做的!不幸的是,我们还没有到那一步。

问题是它还没有按照我们希望的方式在多个维度上泛化到最大值:

>>> array.argmax()  # what??
<xarray.DataArray ()>
array(2)

问题是它会自动变平,就像 np.argmax。相反,我们可能想要类似元组数组或数组元组的东西,指示最大值的原始整数坐标。也欢迎对此做出贡献——有关详细信息,请参阅 this issue

更新:

xarray 现在有 idxmax 方法来选择一维最大值的坐标:


In [8]: da = xr.DataArray(
   ...:     np.random.rand(2,3),
   ...:     dims=list('ab'),
   ...:     coords=dict(a=list('xy'), b=list('ijk'))
   ...: )


In [14]: da
Out[14]:
<xarray.DataArray (a: 2, b: 3)>
array([[0.63059257, 0.00155463, 0.60763418],
       [0.19680788, 0.43953352, 0.05602777]])
Coordinates:
  * a        (a) <U1 'x' 'y'
  * b        (b) <U1 'i' 'j' 'k'

In [13]: da.idxmax('a')
Out[13]:
<xarray.DataArray 'a' (b: 3)>
array(['x', 'y', 'x'], dtype=object)
Coordinates:
  * b        (b) <U1 'i' 'j' 'k'


不过,以下答案仍然适用于多个维度上的最大值。


您可以使用da.where()根据最大值过滤:

In [17]: da = xr.DataArray(
             np.random.rand(2,3), 
             dims=list('ab'), 
             coords=dict(a=list('xy'), b=list('ijk'))
         )

In [18]: da.where(da==da.max(), drop=True).squeeze()
Out[18]:
<xarray.DataArray ()>
array(0.96213673)
Coordinates:
    a        <U1 'x'
    b        <U1 'j'

编辑:更新示例以更清楚地显示索引,现在 xarray 没有默认索引

你也可以使用堆栈:

假设数据是一个具有时间、经度、纬度的 3d 变量,并且您想要随时间变化的最大值的坐标。

stackdata = data.stack(z=('lon', 'lat'))
maxi = stackdata.argmax(axis=1)
maxipos = stackdata['z'][maxi]
lonmax = [maxipos.values[itr][0] for itr in range(ntime)]
latmax = [maxipos.values[itr][1] for itr in range(ntime)]

这将return xarray 数据数组中最大值的坐标点。

max = xarraydata.where(xarraydata==xarraydata.max(), drop=True).squeeze()