为什么我在 LSTM 中添加 relu 激活后得到 Nan?

Why am I getting Nan after adding relu activation in LSTM?

我有一个简单的 LSTM 网络,大致如下所示:

lstm_activation = tf.nn.relu

cells_fw = [LSTMCell(num_units=100, activation=lstm_activation), 
            LSTMCell(num_units=10, activation=lstm_activation)]

stacked_cells_fw = MultiRNNCell(cells_fw)

_, states = tf.nn.dynamic_rnn(cell=stacked_cells_fw,
                              inputs=embedding_layer,
                              sequence_length=features['length'],
                              dtype=tf.float32)

output_states = [s.h for s in states]
states = tf.concat(output_states, 1)

我的问题是。当我不使用激活 (activation=None) 或使用 tanh 时,一切正常,但当我切换 relu 时,我一直得到 "NaN loss during training",这是为什么?它是 100% 可重现的。

当您在 lstm cell 中使用 relu activation function 时,可以保证单元格的所有输出以及单元格状态都将严格为 >= 0。因此,您的渐变变得非常大并且正在爆炸。例如,运行 以下代码片段并观察到输出永远不会 < 0.

X = np.random.rand(4,3,2)
lstm_cell = tf.nn.rnn_cell.LSTMCell(5, activation=tf.nn.relu)
hidden_states, _ = tf.nn.dynamic_rnn(cell=lstm_cell, inputs=X, dtype=tf.float64)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(hidden_states))