用于 xarray 的 JAX pytree

JAX pytree for xarray

是否可以基于 xarray.DataArray 创建一个 pytree(用于 JAX)?

关键问题似乎是 xarray.DataArraydata 属性被构造为输入数组的 numpy.ndarray 视图(例如 DeviceArray

import jax.numpy as jnp
import xarray as xr

da = xr.DataArray(
    data=jnp.array([2.0,3.0]),
    dims=("var"),
    coords={"var": ["A","B"]},
)

>>> type(da.data)

    numpy.ndarray

Flattening/Unflattening 将 DataArray 放入 pytree 相对简单(除 data 外所有属性都是辅助属性),但我不知道如何检索 DeviceArray,甚至分配给它(我不能在这里使用 at[.].set(.))。

构建容器 class(具有 DeviceArray 成员)的替代方法要求手动实现 xarray 的所有相关功能。对于单个功能,例如标记索引,这是可能的但多余的。

按照 documentation.

中的示例,您可以通过将 xarray 类型注册为自定义 PyTree 节点来完成您想要的操作。

例如,它可能看起来像这样:

import jax.numpy as jnp
from jax import tree_util
import xarray as xr

tree_util.register_pytree_node(
    xr.DataArray,
    lambda x: ((x.data,), {"dims": x.dims, "coords": x.coords}),
    lambda kwds, args: xr.DataArray(*args, **kwds)
)

da = xr.DataArray(
    data=jnp.array([2.0,3.0]),
    dims=("var"),
    coords={"var": ["A","B"]},
)
print(da)
# <xarray.DataArray (var: 2)>
# array([2., 3.], dtype=float32)
# Coordinates:
#   * var      (var) <U1 'A' 'B'

data, tree = tree_util.tree_flatten(da)
print(data)
# [array([2., 3.], dtype=float32)]

da_reconstructed = tree_util.tree_unflatten(tree, data)
print(da_reconstructed)
# <xarray.DataArray (var: 2)>
# array([2., 3.], dtype=float32)
# Coordinates:
#   * var      (var) <U1 'A' 'B'

我认为除了最简单的情况外,这在任何情况下都不太可能按预期工作:例如,JAX 转换仅限于功能性代码、non-side-effecting 代码,而 JAX 数组是不可变的。 xarray 的操作通常违反这两个约束。

另一个问题:如果您希望在任何 JAX 函数中使用它,您必须以确保所有数据和元数据正确序列化的方式小心定义展平和展开函数,请注意,您可能 运行 遇到 xarray 的输入验证问题;有关详细信息,请参阅 Custom PyTrees and Initialization