Error performing lstm using BasicLSTMCell in v0.12.0

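I have v0.12.0 installed on my system, and I'm running into problems with a simple LSTM. While trying to get it to work, I even stripped my code down to the example at https://www.tensorflow.org/tutorials/recurrent/, but I still hit the same problem. My code snippet and the error log are below. The "lstm" function takes its input from a convolution function that produces a latent representation (size 1024) for each frame of a 40-frame sequence.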

import tensorflow as tf

frames_batch_size = 40
batch_size = 20

def lstm(x, state_size=1024, initial_state=None, reuse=False):
    with tf.variable_scope("lstm") as lstm_scope:
        if reuse:
            tf.get_variable_scope().reuse_variables()

        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(state_size)
        state = initial_state if initial_state else tf.zeros([batch_size, state_size], name="lstm_state")

        print(x.name, "-", x.get_shape())
        print(state.name, "-", state.get_shape())

        for i in range(frames_batch_size):
            output, state = lstm_cell(x[:, i], state)

        print(output.name, output.get_shape())
    return output

The printed tensor shapes are:

generator/convolution/conv_output:0 - (20, 40, 1024)
generator/lstm/lstm_state:0 - (20, 1024)

A snippet of the error log:

<ipython-input-13-35b652ff4acc> in lstm(x, state_size, initial_state, reuse)
     10 
     11         for i in range(frames_batch_size):
---> 12             output, state = lstm_cell(x[:, i], state)
     13 
     14         print(output.name, output.get_shape())

/usr/local/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell.py in __call__(self, inputs, state, scope)
    306       # Parameters of gates are concatenated into one multiply for efficiency.
    307       if self._state_is_tuple:
--> 308         c, h = state
    309       else:
    310         c, h = array_ops.split(1, 2, state)

/usr/local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py in __iter__(self)
    508       TypeError: when invoked.
    509     """
--> 510     raise TypeError("'Tensor' object is not iterable.")
    511 
    512   def __bool__(self):

TypeError: 'Tensor' object is not iterable.

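Because state_is_tuple defaults to True, BasicLSTMCell expects state to be a (c, h) pair and unpacks it with c, h = state (line 308 of rnn_cell.py in the traceback). A single tensor created with tf.zeros cannot be unpacked that way, hence the TypeError. Initialize state with the cell's zero_state(batch_size, tf.float32) instead, which returns that pair. Note that the method belongs to the cell object (lstm_cell), not to the Python function named lstm.

Replace

state = initial_state if initial_state else tf.zeros([batch_size, state_size], name="lstm_state")

with

state = initial_state if initial_state is not None else lstm_cell.zero_state(batch_size, tf.float32)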
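To see why the original line fails, here is a minimal pure-Python mock that needs no TensorFlow. FakeTensor, zero_state and unpack_state are hypothetical stand-ins, not TensorFlow APIs: FakeTensor refuses iteration the same way Tensor.__iter__ does in the ops.py frame of the traceback, and unpack_state performs the same c, h = state unpacking that rnn_cell.py runs when state_is_tuple=True.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple: a named (c, h) pair.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


class FakeTensor:
    """Hypothetical stand-in for a TF 0.12 tf.Tensor, which cannot be iterated."""

    def __init__(self, shape):
        self.shape = shape

    def __iter__(self):
        # Mirrors the error raised by Tensor.__iter__ in the traceback.
        raise TypeError("'Tensor' object is not iterable.")


def zero_state(batch_size, num_units):
    """Mock of BasicLSTMCell.zero_state with state_is_tuple=True:
    one zero tensor for the cell state c and one for the hidden state h."""
    shape = (batch_size, num_units)
    return LSTMStateTuple(c=FakeTensor(shape), h=FakeTensor(shape))


def unpack_state(state):
    """The unpacking BasicLSTMCell.__call__ performs when state_is_tuple=True."""
    c, h = state
    return c, h


if __name__ == "__main__":
    # Works: zero_state returns a 2-tuple, so it unpacks into c and h.
    c, h = unpack_state(zero_state(20, 1024))
    print(c.shape, h.shape)  # (20, 1024) (20, 1024)

    # Fails: a single (20, 1024) tensor, like tf.zeros([batch_size, state_size]).
    try:
        unpack_state(FakeTensor((20, 1024)))
    except TypeError as err:
        print(err)  # 'Tensor' object is not iterable.
```

The mock only models the structure of the state; the real cell also runs the gate computations on c and h.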

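Also, if you pass your own initial_state, make sure it is an LSTMStateTuple (for example tf.nn.rnn_cell.LSTMStateTuple(c, h)) rather than a single tensor.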
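One way to apply this advice defensively is to normalize initial_state before the loop. The sketch below is a pure-Python mock under stated assumptions: LSTMStateTuple is a namedtuple stand-in for the TensorFlow class, zeros is a hypothetical stand-in for a zero tensor, and make_initial_state is a hypothetical helper, not part of TensorFlow. It falls back to a zero state when nothing is passed, wraps a plain (c, h) pair, and rejects anything else early. It tests initial_state is None explicitly because, in TensorFlow, using a tf.Tensor in an if condition raises instead of returning a boolean.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple (no TensorFlow needed).
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def zeros(batch_size, num_units):
    """Hypothetical stand-in for a zero tensor: a nested list of 0.0 values."""
    return [[0.0] * num_units for _ in range(batch_size)]


def make_initial_state(initial_state, batch_size, num_units):
    """Hypothetical helper: normalize initial_state into an LSTMStateTuple."""
    if initial_state is None:
        # Same (c, h) structure that lstm_cell.zero_state(...) returns.
        return LSTMStateTuple(c=zeros(batch_size, num_units),
                              h=zeros(batch_size, num_units))
    if isinstance(initial_state, LSTMStateTuple):
        # Already the right structure; checked before the generic tuple case
        # because a namedtuple is also a tuple.
        return initial_state
    if isinstance(initial_state, tuple) and len(initial_state) == 2:
        # A plain (c, h) pair: wrap it so callers can use .c and .h.
        return LSTMStateTuple(*initial_state)
    raise TypeError("initial_state must be an LSTMStateTuple (c, h), got %r"
                    % type(initial_state).__name__)


if __name__ == "__main__":
    state = make_initial_state(None, 2, 3)
    print(state.c)  # [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```

In the TensorFlow code itself, the equivalents are lstm_cell.zero_state(batch_size, tf.float32) for the default and tf.nn.rnn_cell.LSTMStateTuple(c, h) for a custom state.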