将来自 tensorflow dynamic_rnn 的输出馈送到后续层

Feeding outputs from tensorflow dynamic_rnn to subsequent layer

我已经开始在 tensorflow 中使用 RNN,并且了解了一般原理,但实现的某些方面不太清楚。

我的理解:假设我正在训练一个序列到序列网络,其中输入与输出大小相同(这可能类似于在每个时间步预测一段文本中的下一个字符).我的循环层使用 LSTM 单元,之后我想要一个全连接层来为预测添加更多深度。

在静态 RNN 中,按照 TF 惯例,您应该在时间维度上拆分输入数据并将其作为列表提供给 static_rnn 方法,如下所示:

import tensorflow as tf

num_input_features = 32
num_output_features = 32

lstm_size = 128
max_seq_len = 5

# input/output:
x = tf.placeholder(tf.float32, [None, max_seq_len, num_input_features])

x_series = tf.unstack(x, axis=1) # a list of length max_seq_len

# recurrent layer:
lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)
rnn_outputs, final_state = tf.nn.static_rnn(lstm_cell, x_series, dtype=tf.float32)

这为您提供了一个输出列表,每个时间步一个。然后如果你想在每一步对 RNN 的输出做一些额外的计算,你可以只对输出列表的每个元素这样做:

# output layer:

w = tf.Variable(tf.random_normal([lstm_size, num_output_features]))
b = tf.Variable(tf.random_normal([num_output_features]))

z_series = [tf.matmul(out, w) + b for out in rnn_outputs]
yhat_series = [tf.nn.tanh(z) for z in z_series]

然后我可以再次叠加 yhat_series 并将其与一些标签 y 进行比较以获得我的成本函数。

这是我没有得到的:在动态 RNN 中,您提供给 dynamic_rnn 方法的输入是一个具有自己的时间维度(默认为轴 1)的张量:

# input/output:
x = tf.placeholder(tf.float32, [None, max_seq_len, num_input_features])

# x_series = tf.unstack(x, axis=1) # dynamic RNN does not need this

# recurrent layer:
lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)
dyn_rnn_outputs, dyn_final_state = tf.nn.dynamic_rnn(lstm_cell, x, dtype=tf.float32)

那么dyn_rnn_output就不是一个列表,而是一个形状为(?,max_seq_len,lstm_size)的张量。将此张量馈送到后续致密层的最佳方法是什么?我无法将 RNN 输出乘以我的权重矩阵,拆开 RNN 输出感觉就像是 dynamic_rnn API 旨在避免的尴尬黑客攻击。

我是否缺少解决此问题的好方法?

任何试图解决这个问题的人的更新:

有一个张量流函数,tf.contrib.rnn.OutputProjectionWrapper,似乎专门用于将密集层附加到 RNN 单元的输出,但将其包装为 RNN 单元本身的一部分,然后您可以调用 tf.nn.dynamic_rnn:

展开
lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)
proj = tf.contrib.rnn.OutputProjectionWrapper(lstm_cell, num_output_features)
dyn_rnn_outputs, dyn_final_state = tf.nn.dynamic_rnn(proj, x, dtype=tf.float32)

但更一般地说,如果您想对 RNN 的输出进行操作,通常的做法似乎是通过跨批次和时间维度展开来重塑 rnn_outputs,然后对该张量执行操作,并将它们回滚以获得最终输出。