如何从 tensorflow RNN 中选取最后一个有效输出值
How to pick the last valid output values from tensorflow RNN
我正在用不同长度的序列批次训练 LSTM 单元。 tf.nn.rnn
有一个非常方便的参数sequence_length
,但是调用它后,我不知道如何选择批次中每个项目的最后一个时间步对应的输出行。
我的代码基本如下:
lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size)
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths)
lstm_outputs
是一个列表,其中包含每个时间步长的 LSTM 输出。但是,我的批次中的每个项目都有不同的长度,因此我想创建一个张量,其中包含对我的批次中的每个项目有效的最后一个 LSTM 输出。
如果我可以使用 numpy 索引,我会这样做:
all_outputs = tf.pack(lstm_outputs)
last_outputs = all_outputs[sequence_lengths, tf.range(batch_size), :]
但事实证明暂时 begin tensorflow 不支持它(我知道 feature request)。
那么,我怎样才能得到这些值呢?
这不是最好的解决方案,但您可以评估您的输出,然后使用 numpy 索引获取结果并从中创建一个张量变量?在 tensorflow 获得此功能之前,它可能会作为权宜之计。例如
all_outputs = session.run(lstm_outputs, feed_dict={'your inputs'})
last_outputs = all_outputs[sequence_lengths, tf.range(batch_size), :]
use_this_as_an_input_to_new_tensorflow_op = tf.constant(last_outputs)
danijar 在我在问题中链接的功能请求页面上发布了一个更可接受的解决方法。它不需要评估张量,这是一个很大的优势。
我让它可以与 tensorflow 0.8 一起使用。这是代码:
def extract_last_relevant(outputs, length):
"""
Args:
outputs: [Tensor(batch_size, output_neurons)]: A list containing the output
activations of each in the batch for each time step as returned by
tensorflow.models.rnn.rnn.
length: Tensor(batch_size): The used sequence length of each example in the
batch with all later time steps being zeros. Should be of type tf.int32.
Returns:
Tensor(batch_size, output_neurons): The last relevant output activation for
each example in the batch.
"""
output = tf.transpose(tf.pack(outputs), perm=[1, 0, 2])
# Query shape.
batch_size = tf.shape(output)[0]
max_length = int(output.get_shape()[1])
num_neurons = int(output.get_shape()[2])
# Index into flattened array as a workaround.
index = tf.range(0, batch_size) * max_length + (length - 1)
flat = tf.reshape(output, [-1, num_neurons])
relevant = tf.gather(flat, index)
return relevant
如果您只对最后一个有效输出感兴趣,您可以通过 tf.nn.rnn()
返回的状态检索它,考虑到它始终是一个元组 (c, h),其中 c 是最后一个状态,h 是最后的输出。当状态为 LSTMStateTuple
时,您可以使用以下代码片段(在 tensorflow 0.12 中工作):
lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size)
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths)
last_output = state[1]
我正在用不同长度的序列批次训练 LSTM 单元。 tf.nn.rnn
有一个非常方便的参数sequence_length
,但是调用它后,我不知道如何选择批次中每个项目的最后一个时间步对应的输出行。
我的代码基本如下:
lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size)
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths)
lstm_outputs
是一个列表,其中包含每个时间步长的 LSTM 输出。但是,我的批次中的每个项目都有不同的长度,因此我想创建一个张量,其中包含对我的批次中的每个项目有效的最后一个 LSTM 输出。
如果我可以使用 numpy 索引,我会这样做:
all_outputs = tf.pack(lstm_outputs)
last_outputs = all_outputs[sequence_lengths, tf.range(batch_size), :]
但事实证明暂时 begin tensorflow 不支持它(我知道 feature request)。
那么,我怎样才能得到这些值呢?
这不是最好的解决方案,但您可以评估您的输出,然后使用 numpy 索引获取结果并从中创建一个张量变量?在 tensorflow 获得此功能之前,它可能会作为权宜之计。例如
all_outputs = session.run(lstm_outputs, feed_dict={'your inputs'})
last_outputs = all_outputs[sequence_lengths, tf.range(batch_size), :]
use_this_as_an_input_to_new_tensorflow_op = tf.constant(last_outputs)
danijar 在我在问题中链接的功能请求页面上发布了一个更可接受的解决方法。它不需要评估张量,这是一个很大的优势。
我让它可以与 tensorflow 0.8 一起使用。这是代码:
def extract_last_relevant(outputs, length):
"""
Args:
outputs: [Tensor(batch_size, output_neurons)]: A list containing the output
activations of each in the batch for each time step as returned by
tensorflow.models.rnn.rnn.
length: Tensor(batch_size): The used sequence length of each example in the
batch with all later time steps being zeros. Should be of type tf.int32.
Returns:
Tensor(batch_size, output_neurons): The last relevant output activation for
each example in the batch.
"""
output = tf.transpose(tf.pack(outputs), perm=[1, 0, 2])
# Query shape.
batch_size = tf.shape(output)[0]
max_length = int(output.get_shape()[1])
num_neurons = int(output.get_shape()[2])
# Index into flattened array as a workaround.
index = tf.range(0, batch_size) * max_length + (length - 1)
flat = tf.reshape(output, [-1, num_neurons])
relevant = tf.gather(flat, index)
return relevant
如果您只对最后一个有效输出感兴趣,您可以通过 tf.nn.rnn()
返回的状态检索它,考虑到它始终是一个元组 (c, h),其中 c 是最后一个状态,h 是最后的输出。当状态为 LSTMStateTuple
时,您可以使用以下代码片段(在 tensorflow 0.12 中工作):
lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size)
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths)
last_output = state[1]