tensorflow 的 tf.contrib.training.batch_sequences_with_states API 是如何工作的?

How does tensorflow's tf.contrib.training.batch_sequences_with_states API work?

我正在处理必须传递给 RNN 的长序列数据。要进行截断的 BPTT 和批处理,似乎有两种选择:

  1. 通过组合来自不同序列的相应 个片段创建一个批次。保留批次中每个序列的最终状态并将其传递给下一批次。
  2. 将每个序列视为一个小批量,序列中的片段成为批量的成员。保留一个段中最后一个时间步的状态,并将其传递到下一个段的第一个时间步。

我遇到了 tf.contrib.training.batch_sequences_with_states,这似乎是两者之一。该文档让我感到困惑,因此我想确定它以哪种方式生成批次。

我猜它是第一种方式。那是因为,如果以第二种方式进行批处理,那么我们就无法利用向量化的好处,因为要保留一个段的最后一个时间步到下一个段的第一个时间步之间的状态,RNN 应该处理一个一次按顺序标记。

问题:

这两种批处理策略中的哪一种在 tf.contrib.training.batch_sequences_with_states 中实现?

tf.contrib.training.batch_sequences_with_states 实现了前一种行为。每个小批量条目都是来自不同序列的片段(每个序列,可以由可变数量的片段组成,有一个唯一的键,这个键被传递到 batch_sequences_with_states)。当与 state_saving_rnn 一起使用时,每个段的最终状态被保存回一个特殊的存储容器,允许给定序列的下一个段在下一个 sess.run 处为 运行。最终片段为不同的序列释放了一个小批量插槽。