尝试为 RNN 重用权重时出错
Error while trying to reuse weights for RNN
我正在尝试将双向 LSTM 权重重新用于 2 个非常相似的计算,但出现错误并且不知道我做错了什么。
我有一个基本模块的 class :
class BasicAttn(object):
def __init__(self, keep_prob, value_vec_size):
self.rnn_cell_fw = rnn_cell.LSTMCell(value_vec_size/2, reuse=True)
self.rnn_cell_fw = DropoutWrapper(self.rnn_cell_fw, input_keep_prob=self.keep_prob)
self.rnn_cell_bw = rnn_cell.LSTMCell(value_vec_size/2, reuse=True)
self.rnn_cell_bw = DropoutWrapper(self.rnn_cell_bw, input_keep_prob=self.keep_prob)
def build_graph(self, values, values_mask, keys):
blended_reps = compute_blended_reps()
with tf.variable_scope('BasicAttn_BRNN', reuse=True):
(fw_out, bw_out), _ =
tf.nn.bidirectional_dynamic_rnn(self.rnn_cell_fw, self.rnn_cell_bw, blended_reps, dtype=tf.float32, scope='BasicAttn_BRNN')
然后,在构建图形时调用该模块
attn_layer_start = BasicAttn(...)
blended_reps_start = attn_layer_start.build_graph(...)
attn_layer_end = BasicAttn(...)
blended_reps_end = attn_layer_end.build_graph(...)
但是我收到错误消息说 TensorFlow 无法重用 RNN?
ValueError: Variable QAModel/BasicAttn_BRNN/BasicAttn_BRNN/fw/lstm_cell/kernel does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope
代码比较多,删掉了觉得不需要的部分
reuse=True
表示变量之前已使用 reuse=False
创建,因此每个 tf.get_variable
(在您的情况下是在 LSTM 接口后面抽象的)期望变量已经存在。
要有一种模式,在变量不存在时创建变量,否则重新使用,您需要设置 reuse=tf.AUTO_REUSE
(如错误消息所示)。
所以用 reuse=tf.AUTO_REUSE
替换所有出现的 reuse=True
这是文档:https://www.tensorflow.org/api_docs/python/tf/variable_scope
我正在尝试将双向 LSTM 权重重新用于 2 个非常相似的计算,但出现错误并且不知道我做错了什么。 我有一个基本模块的 class :
class BasicAttn(object):
def __init__(self, keep_prob, value_vec_size):
self.rnn_cell_fw = rnn_cell.LSTMCell(value_vec_size/2, reuse=True)
self.rnn_cell_fw = DropoutWrapper(self.rnn_cell_fw, input_keep_prob=self.keep_prob)
self.rnn_cell_bw = rnn_cell.LSTMCell(value_vec_size/2, reuse=True)
self.rnn_cell_bw = DropoutWrapper(self.rnn_cell_bw, input_keep_prob=self.keep_prob)
def build_graph(self, values, values_mask, keys):
blended_reps = compute_blended_reps()
with tf.variable_scope('BasicAttn_BRNN', reuse=True):
(fw_out, bw_out), _ =
tf.nn.bidirectional_dynamic_rnn(self.rnn_cell_fw, self.rnn_cell_bw, blended_reps, dtype=tf.float32, scope='BasicAttn_BRNN')
然后,在构建图形时调用该模块
attn_layer_start = BasicAttn(...)
blended_reps_start = attn_layer_start.build_graph(...)
attn_layer_end = BasicAttn(...)
blended_reps_end = attn_layer_end.build_graph(...)
但是我收到错误消息说 TensorFlow 无法重用 RNN?
ValueError: Variable QAModel/BasicAttn_BRNN/BasicAttn_BRNN/fw/lstm_cell/kernel does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope
代码比较多,删掉了觉得不需要的部分
reuse=True
表示变量之前已使用 reuse=False
创建,因此每个 tf.get_variable
(在您的情况下是在 LSTM 接口后面抽象的)期望变量已经存在。
要有一种模式,在变量不存在时创建变量,否则重新使用,您需要设置 reuse=tf.AUTO_REUSE
(如错误消息所示)。
所以用 reuse=tf.AUTO_REUSE
reuse=True
这是文档:https://www.tensorflow.org/api_docs/python/tf/variable_scope