TensorFlow:Android 中推理之间的 RNN 初始化状态

TensorFlow: initializing state for RNN between inferences in Android

我们在 Android 上有一个工作的 TensorFlow 网络 (graphdef) 运行,我注意到随着时间的推移推理结果往往是相关的。也就是说,如果返回标签 A,即使输入数据切换到应生成 B 标签的数据,也往往会及时出现 A 流。最终,结果将切换到 B,但似乎存在滞后,表明 RNN 在推理调用之间是有状态的。我们的网络正在使用 RNN/LSTMs.

cellLSTM    = tf.nn.rnn_cell.BasicLSTMCell(nHidden)
cellsLSTM   = tf.nn.rnn_cell.MultiRNNCell([cellLSTM] * 2)
RNNout, RNNstates = tf.nn.rnn(cellsLSTM, Xin)

我想知道是否需要在推理调用之间重新初始化 RNN 状态。我会注意到 TensorFlowInferenceInterface.java 界面中没有这样的方法。我想可以将 RNN 单元初始化节点插入到可以用节点值激活的图中(使用 FillNodeInt 或类似的)。

所以我的问题是:Tensorflow 中 RNN/LSTMs 的最佳实践是什么。是否需要清除推理之间的状态?如果可以,怎么做呢?

Does one need to clear state between inferences?

我认为这取决于 RNN 的训练方式以及您的使用方式。但是,我猜想无论是否重置状态,网络都可以很好地工作。

how does one do it?

评估与初始状态关联的每个张量的初始化操作。

虽然我无法对 RNN 状态初始化的一般做法发表评论,但以下是我们设法强制执行初始状态定义的方法。问题在于,虽然批量大小确实是训练集的常量参数,但它不是测试集。测试集总是占数据语料库的 20%,因此它的大小随着语料库的每次变化而不同。
解决方案是为 batchsize 创建一个新变量:

batch_size_T  = tf.shape(Xin)[0]

其中 Xin 是大小为 [b x m x n] 的输入张量,其中 b 是批量大小,m x n 是训练帧的大小。 Xin 是从 feed_dict.

输入的

初始状态可以定义为:

initial_state = lstm_cells.zero_state(batch_size_T, tf.float32) 

最后,根据新的动态 RNN 定义 RNN:

outputs, state = tf.nn.dynamic_rnn(cell=lstm_cells, inputs=Xin, dtype=tf.float32, initial_state=initial_state)