用 flax.nn.Module 实现循环神经网络
Implementing RNN with flax.nn.Module
我正在尝试使用 flax.nn.Module
实现一个基本的 RNN 单元。实现 RNN 单元的方程非常简单:
a_t = W * h_{t-1} + U * x_t + b
h_t = tanh(a_t)
o_t = V * h_t + c
其中 h_t 是时间 t 的更新状态,x_t 是输入,o_t 是输出,Tanh 是我们的激活函数。
我的代码使用 flax.nn.Module
,
class ElmanCell(nn.Module):
@nn.compact
def __call__(self, h, x):
nextState = jnp.tanh(jnp.dot(W, h) * jnp.dot(U, x) + b)
return nextState
我不知道如何实现参数 W、U 和 b。它们应该是 nn.Module 的属性吗?
试试这样的东西:
class RNNCell(nn.Module):
@nn.compact
def __call__(self, state, x):
# Wh @ h + Wx @ x + b can be efficiently computed
# by concatenating the vectors and then having a single dense layer
x = np.concatenate([state, x])
new_state = np.tanh(nn.Dense(state.shape[0])(x))
return new_state
这样就可以学习参数了。参见 https://schmit.github.io/jax/2021/06/20/jax-language-model-rnn.html
我正在尝试使用 flax.nn.Module
实现一个基本的 RNN 单元。实现 RNN 单元的方程非常简单:
a_t = W * h_{t-1} + U * x_t + b
h_t = tanh(a_t)
o_t = V * h_t + c
其中 h_t 是时间 t 的更新状态,x_t 是输入,o_t 是输出,Tanh 是我们的激活函数。
我的代码使用 flax.nn.Module
,
class ElmanCell(nn.Module):
@nn.compact
def __call__(self, h, x):
nextState = jnp.tanh(jnp.dot(W, h) * jnp.dot(U, x) + b)
return nextState
我不知道如何实现参数 W、U 和 b。它们应该是 nn.Module 的属性吗?
试试这样的东西:
class RNNCell(nn.Module):
@nn.compact
def __call__(self, state, x):
# Wh @ h + Wx @ x + b can be efficiently computed
# by concatenating the vectors and then having a single dense layer
x = np.concatenate([state, x])
new_state = np.tanh(nn.Dense(state.shape[0])(x))
return new_state
这样就可以学习参数了。参见 https://schmit.github.io/jax/2021/06/20/jax-language-model-rnn.html