如何获取 jax.tree_flatten 对象的键?

How to get keys for jax.tree_flatten object?

考虑一个简单的嵌套配置

p = {'a': {'b': 1.0, 'c': 2.0}}
jax.tree_flatten(p)
p = {'a': {'b': 1.0, 'c': 2.0}}
jax.tree_flatten(p)
([1.0, 2.0], PyTreeDef({'a': {'b': *, 'c': *}}))

我怎样才能获得某种标签,例如 ['a.b'、'a.c'] 或任何其他符合 tree_flatten 顺序的合理标签?

jax.tree_util 没有内置此机制。在某种程度上,问题是 ill-posed:树扁平化适用于比嵌套字典更通用的 class 对象,如您的示例;您甚至可以为任意对象定义 pytree 展平(参见 https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees),我不清楚在这种一般情况下如何为展平对象构建标签。

如果您只关心嵌套的字典并且想要生成这些类型的扁平化标签,最好的办法可能是编写您自己的 Python 代码来构造扁平化的键和值;例如这样的事情可能会起作用:

p = {'a': {'b': 1.0, 'c': 2.0}}

def flatten(p, label=None):
  if isinstance(p, dict):
    for k, v in p.items():
      yield from flatten(v, k if label is None else f"{label}.{k}")
  else:
    yield (label, p)

print(dict(flatten(p)))
# {'a.b': 1.0, 'a.c': 2.0}