Tensorflow RNN 细胞具有不同的权重

Tensorflow RNN cells have different weights

我正在尝试根据此处的教程在 tensorflow 中编写一个简单的 RNN:https://danijar.com/introduction-to-recurrent-networks-in-tensorflow/ (我使用的是简单的 RNN 单元而不是 GRU,并且没有使用 dropout)。

我很困惑,因为我的序列中的不同 RNN 单元似乎被分配了不同的权重。如果我运行下面的代码

import tensorflow as tf

seq_length = 3
n_h = 100   # Number of hidden units
n_x = 26    # Size of input layer
n_y = 26    # Size of output layer

inputs = tf.placeholder(tf.float32, [None, seq_length, n_x])

cells = []
for _ in range(seq_length):
    cell = tf.contrib.rnn.BasicRNNCell(n_h)
    cells.append(cell)
multi_rnn_cell = tf.contrib.rnn.MultiRNNCell(cells)

initial_state = tf.placeholder(tf.float32, [None, n_h])

outputs_h, output_final_state = tf.nn.dynamic_rnn(multi_rnn_cell, inputs, dtype=tf.float32)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

print('Trainable variables:')
for v in tf.trainable_variables():
    print(v)

如果我在 python 3 中 运行 这个,我得到以下输出:

Trainable variables:
<tf.Variable 'rnn/multi_rnn_cell/cell_0/basic_rnn_cell/kernel:0' shape=(126, 100) dtype=float32_ref>
<tf.Variable 'rnn/multi_rnn_cell/cell_0/basic_rnn_cell/bias:0' shape=(100,) dtype=float32_ref>
<tf.Variable 'rnn/multi_rnn_cell/cell_1/basic_rnn_cell/kernel:0' shape=(200, 100) dtype=float32_ref>
<tf.Variable 'rnn/multi_rnn_cell/cell_1/basic_rnn_cell/bias:0' shape=(100,) dtype=float32_ref>
<tf.Variable 'rnn/multi_rnn_cell/cell_2/basic_rnn_cell/kernel:0' shape=(200, 100) dtype=float32_ref>
<tf.Variable 'rnn/multi_rnn_cell/cell_2/basic_rnn_cell/bias:0' shape=(100,) dtype=float32_ref>

首先,这不是我想要的 - RNN 需要在每一层从输入到隐藏和隐藏到隐藏具有相同的权重!

其次,我不太明白为什么我得到所有这些单独的变量。如果我查看 source code for rnn cells,它看起来像 BasicRNNCell 应该调用 _linear,它应该查找是否有名称为 _WEIGHTS_VARIABLE_NAME 的变量(全局设置为 "kernel"),如果是这样,请使用它。我不明白 "kernel" 是如何装饰成 "rnn/multi_rnn_cell/cell_0/basic_rnn_cell/kernel:0".

如果有人能解释我做错了什么,我将不胜感激。

注意区分两个不同的东西:你的递归神经网络的层数和这个 RNN 被时间反向传播算法展开以处理序列长度的次数。

在您的代码中:

  • MultiCellRNN 负责创建一个 3 层 RNN(您正在那里创建三个 LAYERS,而 MultiCellRNN 只是一个包装器,以便于处理它们)
  • tf.nn.dynamic_rnn 负责根据您的序列长度展开此三层网络的次数