如何从俳句中的参数(pytree)获取参数? (jax框架)

How does one get a parameter from a params (pytree) in haiku? (jax framework)

例如,您设置了一个具有参数的模块。但是如果你想在损失中对某些东西进行正则化,那么模式是什么?

import jax.numpy as jnp
import jax
def loss(params, x, y):
   l = jnp.sum((y - mlp.apply(params, x)) ** 2)
   w = hk.get_params(params, 'w') # does not work like this
   l += jnp.sum(w ** w)
   return l

示例中缺少一些模式。

params本质上是一个只读字典,所以你可以把参数当作字典来获取参数的值:

print(params['w'])

如果要更新参数,不能就地进行,必须先将其转换为可变字典:

params_mutable = hk.data_structures.to_mutable_dict(params)
params_mutable['w'] = 3.14
params_new = hk.data_structures.to_immutable_dict(params_mutable)