Tensorflow:如何从 rnn_cell.BasicLSTM & rnn_cell.MultiRNNCell 获取所有变量

Tensorflow: How to get all variables from rnn_cell.BasicLSTM & rnn_cell.MultiRNNCell

我有一个设置,我需要在使用 tf.initialize_all_variables() 的主要初始化之后初始化 LSTM。 IE。我想打电话给tf.initialize_variables([var_list])

有没有办法收集两者的所有内部可训练变量:

这样我就可以初始化JUST这些参数?

我想要这个的主要原因是因为我不想重新初始化之前训练过的一些值。

解决您的问题的最简单方法是使用变量作用域。范围内变量的名称将以其名称为前缀。这是一个简短的片段:

cell = rnn_cell.BasicLSTMCell(num_nodes)

with tf.variable_scope("LSTM") as vs:
  # Execute the LSTM cell here in any way, for example:
  for i in range(num_steps):
    output[i], state = cell(input_data[i], state)

  # Retrieve just the LSTM variables.
  lstm_variables = [v for v in tf.all_variables()
                    if v.name.startswith(vs.name)]

# [..]
# Initialize the LSTM variables.
tf.initialize_variables(lstm_variables)

它与 MultiRNNCell 的工作方式相同。

编辑:将 tf.trainable_variables 更改为 tf.all_variables()

您也可以使用 tf.get_collection():

cell = rnn_cell.BasicLSTMCell(num_nodes)
with tf.variable_scope("LSTM") as vs:
  # Execute the LSTM cell here in any way, for example:
  for i in range(num_steps):
    output[i], state = cell(input_data[i], state)

  lstm_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name)

(部分复制自 Rafal 的回答)

注意最后一行相当于拉法尔代码中的列表理解。

基本上,tensorflow 存储了一个全局变量集合,可以通过 tf.all_variables()tf.get_collection(tf.GraphKeys.VARIABLES) 获取。如果您在 tf.get_collection() 函数中指定 scope(作用域名称),那么您只会获取作用域在指定作用域下的集合中的张量(在本例中为变量)。

编辑: 您还可以使用 tf.GraphKeys.TRAINABLE_VARIABLES 仅获取可训练变量。但是由于 vanilla BasicLSTMCell 没有初始化任何不可训练的变量,所以两者在功能上是等价的。如需默认图形集合的完整列表,请查看 this