用于 xarray 的 JAX pytree
JAX pytree for xarray
是否可以基于 xarray.DataArray
创建一个 pytree(用于 JAX)?
关键问题似乎是 xarray.DataArray
的 data
属性被构造为输入数组的 numpy.ndarray
视图(例如 DeviceArray
)
import jax.numpy as jnp
import xarray as xr
da = xr.DataArray(
data=jnp.array([2.0,3.0]),
dims=("var"),
coords={"var": ["A","B"]},
)
>>> type(da.data)
numpy.ndarray
Flattening/Unflattening 将 DataArray
放入 pytree 相对简单(除 data
外所有属性都是辅助属性),但我不知道如何检索 DeviceArray
,甚至分配给它(我不能在这里使用 at[.].set(.)
)。
构建容器 class(具有 DeviceArray
成员)的替代方法要求手动实现 xarray 的所有相关功能。对于单个功能,例如标记索引,这是可能的但多余的。
按照 documentation.
中的示例,您可以通过将 xarray 类型注册为自定义 PyTree 节点来完成您想要的操作。
例如,它可能看起来像这样:
import jax.numpy as jnp
from jax import tree_util
import xarray as xr
tree_util.register_pytree_node(
xr.DataArray,
lambda x: ((x.data,), {"dims": x.dims, "coords": x.coords}),
lambda kwds, args: xr.DataArray(*args, **kwds)
)
da = xr.DataArray(
data=jnp.array([2.0,3.0]),
dims=("var"),
coords={"var": ["A","B"]},
)
print(da)
# <xarray.DataArray (var: 2)>
# array([2., 3.], dtype=float32)
# Coordinates:
# * var (var) <U1 'A' 'B'
data, tree = tree_util.tree_flatten(da)
print(data)
# [array([2., 3.], dtype=float32)]
da_reconstructed = tree_util.tree_unflatten(tree, data)
print(da_reconstructed)
# <xarray.DataArray (var: 2)>
# array([2., 3.], dtype=float32)
# Coordinates:
# * var (var) <U1 'A' 'B'
我认为除了最简单的情况外,这在任何情况下都不太可能按预期工作:例如,JAX 转换仅限于功能性代码、non-side-effecting 代码,而 JAX 数组是不可变的。 xarray 的操作通常违反这两个约束。
另一个问题:如果您希望在任何 JAX 函数中使用它,您必须以确保所有数据和元数据正确序列化的方式小心定义展平和展开函数,请注意,您可能 运行 遇到 xarray 的输入验证问题;有关详细信息,请参阅 Custom PyTrees and Initialization。
是否可以基于 xarray.DataArray
创建一个 pytree(用于 JAX)?
关键问题似乎是 xarray.DataArray
的 data
属性被构造为输入数组的 numpy.ndarray
视图(例如 DeviceArray
)
import jax.numpy as jnp
import xarray as xr
da = xr.DataArray(
data=jnp.array([2.0,3.0]),
dims=("var"),
coords={"var": ["A","B"]},
)
>>> type(da.data)
numpy.ndarray
Flattening/Unflattening 将 DataArray
放入 pytree 相对简单(除 data
外所有属性都是辅助属性),但我不知道如何检索 DeviceArray
,甚至分配给它(我不能在这里使用 at[.].set(.)
)。
构建容器 class(具有 DeviceArray
成员)的替代方法要求手动实现 xarray 的所有相关功能。对于单个功能,例如标记索引,这是可能的但多余的。
按照 documentation.
中的示例,您可以通过将 xarray 类型注册为自定义 PyTree 节点来完成您想要的操作。例如,它可能看起来像这样:
import jax.numpy as jnp
from jax import tree_util
import xarray as xr
tree_util.register_pytree_node(
xr.DataArray,
lambda x: ((x.data,), {"dims": x.dims, "coords": x.coords}),
lambda kwds, args: xr.DataArray(*args, **kwds)
)
da = xr.DataArray(
data=jnp.array([2.0,3.0]),
dims=("var"),
coords={"var": ["A","B"]},
)
print(da)
# <xarray.DataArray (var: 2)>
# array([2., 3.], dtype=float32)
# Coordinates:
# * var (var) <U1 'A' 'B'
data, tree = tree_util.tree_flatten(da)
print(data)
# [array([2., 3.], dtype=float32)]
da_reconstructed = tree_util.tree_unflatten(tree, data)
print(da_reconstructed)
# <xarray.DataArray (var: 2)>
# array([2., 3.], dtype=float32)
# Coordinates:
# * var (var) <U1 'A' 'B'
我认为除了最简单的情况外,这在任何情况下都不太可能按预期工作:例如,JAX 转换仅限于功能性代码、non-side-effecting 代码,而 JAX 数组是不可变的。 xarray 的操作通常违反这两个约束。
另一个问题:如果您希望在任何 JAX 函数中使用它,您必须以确保所有数据和元数据正确序列化的方式小心定义展平和展开函数,请注意,您可能 运行 遇到 xarray 的输入验证问题;有关详细信息,请参阅 Custom PyTrees and Initialization。