升级代码 rnn.static_bidirectional_rnn 以适应 tensorflow 2.0 API
upgrade code rnn.static_bidirectional_rnn to fit with tensorflow 2.0 API
import tensorflow as tf
from tf.contrib import rnn
lstm_f = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
lstm_b = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
blstm_out, state_f, state_b = rnn.static_bidirectional_rnn(lstm_f, lstm_b, x, dtype=tf.float32)
上面的代码适用于 tensorflow 1.x,但是我觉得很难找到使用 tensorflow 2.0 重写此代码的方法 API。
我知道我应该从 tf.keras.layers.LSTMCell() 开始,但我不知道什么是 API 函数以适应 2 个 LSTMCell 实例作为输入。
Keras 等价于您的代码段是
lstm = keras.layers.LSTM(n_hidden, unit_forget_bias=True, unroll=True)
keras.layers.Bidirectional(lstm)
请注意,虽然 Keras 具有 LSTMCell
的实现,但您可能希望使用 LSTM
instead, which is not just a cell but a fully unrolled RNN operating on the whole sequence at once. By default, the RNN is unrolled dynamically via a while loop, we force it to be static (in TF 1.X terms) by passing unroll=True
. Finally, keras.layers.Bidirectional
包装器使 RNN 成为双向的。
import tensorflow as tf
from tf.contrib import rnn
lstm_f = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
lstm_b = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
blstm_out, state_f, state_b = rnn.static_bidirectional_rnn(lstm_f, lstm_b, x, dtype=tf.float32)
上面的代码适用于 tensorflow 1.x,但是我觉得很难找到使用 tensorflow 2.0 重写此代码的方法 API。
我知道我应该从 tf.keras.layers.LSTMCell() 开始,但我不知道什么是 API 函数以适应 2 个 LSTMCell 实例作为输入。
Keras 等价于您的代码段是
lstm = keras.layers.LSTM(n_hidden, unit_forget_bias=True, unroll=True)
keras.layers.Bidirectional(lstm)
请注意,虽然 Keras 具有 LSTMCell
的实现,但您可能希望使用 LSTM
instead, which is not just a cell but a fully unrolled RNN operating on the whole sequence at once. By default, the RNN is unrolled dynamically via a while loop, we force it to be static (in TF 1.X terms) by passing unroll=True
. Finally, keras.layers.Bidirectional
包装器使 RNN 成为双向的。