如何在 tensorflow C/C++ 中输入和检索 LSTM 的状态

How to feed in and retrieve state of LSTM in tensorflow C/ C++

我想在 python 中构建和训练多层 LSTM 模型 (stateIsTuple=True),然后在 C++ 中加载和使用它。但是我很难弄清楚如何在 C++ 中提供和获取状态,主要是因为我没有可以引用的字符串名称。

例如我将初始状态放在命名范围内,例如

    with tf.name_scope('rnn_input_state'):
        self.initial_state = cell.zero_state(args.batch_size, tf.float32)

这出现在下图中,但我如何在 C++ 中提供给它们?

此外,如何在 C++ 中获取当前状态?我在 python 中尝试了下面的图构造代码,但我不确定这样做是否正确,因为 last_state 应该是张量元组,而不是单个张量(尽管我可以看到tensorboard 中的 last_state 节点是 2x2x50x128,这听起来像是连接状态,因为我有 2 层、128 个 rnn 大小、50 个小批量大小和 lstm 单元 - 具有 2 个状态向量)。

    with tf.name_scope('outputs'):
        outputs, last_state = legacy_seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if infer else None)
        output = tf.reshape(tf.concat(outputs, 1), [-1, args.rnn_size], name='output')

这就是它在 tensorboard 中的样子

我是否应该连接并拆分状态张量,以便只有一个状态张量进出?或者有更好的方法吗?

P.S。理想情况下,解决方案不会涉及对层数(或 rnn 大小)进行硬编码。所以我只能有四个字符串 input_node_name、output_node_name、input_state_name、output_state_name,其余的都是从那里派生的。

我通过手动将状态连接成一个张量来设法做到这一点。我不确定这是否明智,因为这就是 tensorflow 使用 处理状态的方式,但现在弃用它并切换到元组状态。我没有设置 state_is_tuple=False 并冒着我的代码很快就会过时的风险,而是添加了额外的操作来手动堆叠和取消堆叠单个张量的状态。也就是说,它在 python 和 C++ 中都运行良好。

关键代码是:

# setting up
zero_state = cell.zero_state(batch_size, tf.float32)
state_in = tf.identity(zero_state, name='state_in')         

# based on https://medium.com/@erikhallstrm/using-the-tensorflow-multilayered-lstm-api-f6e7da7bbe40#.zhg4zwteg
state_per_layer_list = tf.unstack(state_in, axis=0)
state_in_tuple = tuple(
    # TODO make this not hard-coded to LSTM
    [tf.contrib.rnn.LSTMStateTuple(state_per_layer_list[idx][0], state_per_layer_list[idx][1])
    for idx in range(num_layers)]
)

outputs, state_out_tuple = legacy_seq2seq.rnn_decoder(inputs, state_in_tuple, cell, loop_function=loop if infer else None)
state_out = tf.identity(state_out_tuple, name='state_out')

# running (training or inference)
state = sess.run('state_in:0') # zero state

loop:
    feed = {'data_in:0': x, 'state_in:0': state}
    [y, state] = sess.run(['data_out:0', 'state_out:0'], feed)

如果有人需要的话,这是完整的代码 https://github.com/memo/char-rnn-tensorflow