Error performing lstm using BasicLSTMCell in v0.12.0
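I have v0.12.0 installed on my system, and I am running into problems running a simple LSTM. While trying to get it to work, I even reduced my code to the example at https://www.tensorflow.org/tutorials/recurrent/, but I still face the same issue. My code snippet and the error log are below. The "lstm" function takes its input from a convolution function that produces latent representations (size: 1024) of a 40-frame sequence.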
import tensorflow as tf

frames_batch_size = 40
batch_size = 20

def lstm(x, state_size=1024, initial_state=None, reuse=False):
    with tf.variable_scope("lstm") as lstm_scope:
        if reuse:
            tf.get_variable_scope().reuse_variables()
        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(state_size)
        state = initial_state if initial_state else tf.zeros([batch_size, state_size], name="lstm_state")
        print(x.name, "-", x.get_shape())
        print(state.name, "-", state.get_shape())
        for i in range(frames_batch_size):
            output, state = lstm_cell(x[:, i], state)
        print(output.name, output.get_shape())
        return output
The printed tensor shapes are:
generator/convolution/conv_output:0 - (20, 40, 1024)
generator/lstm/lstm_state:0 - (20, 1024)
A snippet of the error log:
<ipython-input-13-35b652ff4acc> in lstm(x, state_size, initial_state, reuse)
10
11 for i in range(frames_batch_size):
---> 12 output, state = lstm_cell(x[:, i], state)
13
14 print(output.name, output.get_shape())
/usr/local/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell.py in __call__(self, inputs, state, scope)
306 # Parameters of gates are concatenated into one multiply for efficiency.
307 if self._state_is_tuple:
--> 308 c, h = state
309 else:
310 c, h = array_ops.split(1, 2, state)
/usr/local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py in __iter__(self)
508 TypeError: when invoked.
509 """
--> 510 raise TypeError("'Tensor' object is not iterable.")
511
512 def __bool__(self):
TypeError: 'Tensor' object is not iterable.
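The last two traceback frames show the mechanism: with state_is_tuple set, the cell runs c, h = state, and unpacking calls Tensor.__iter__, which raises the TypeError. The pure-Python sketch below reproduces that behaviour without TensorFlow. FakeTensor, unpack_state and try_unpack are hypothetical stand-ins, and the local LSTMStateTuple only mimics TensorFlow's (c, h) named tuple; none of this is TensorFlow code.

```python
from collections import namedtuple

# Hypothetical stand-in for TensorFlow's LSTMStateTuple: a named (c, h) pair.
LSTMStateTuple = namedtuple("LSTMStateTuple", ("c", "h"))


class FakeTensor:
    """Hypothetical stand-in for tf.Tensor: like the ops.py code in the
    traceback, iterating over it raises TypeError."""

    def __iter__(self):
        raise TypeError("'Tensor' object is not iterable.")


def unpack_state(state):
    # Mirrors the `c, h = state` line executed when state_is_tuple is True.
    c, h = state
    return c, h


def try_unpack(state):
    """Return "ok" if the state unpacks into (c, h), else the error message."""
    try:
        unpack_state(state)
        return "ok"
    except TypeError as err:
        return str(err)


print(try_unpack(FakeTensor()))                  # single tensor: fails
print(try_unpack(LSTMStateTuple(c=0.0, h=0.0)))  # (c, h) pair: works
```

A single tensor, such as the tf.zeros state in the question, fails exactly like the log. A two-element (c, h) pair unpacks cleanly.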
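Since state_is_tuple defaults to True, BasicLSTMCell expects the state to be a (c, h) pair and unpacks it with c, h = state. A single tf.zeros tensor cannot be unpacked, hence the TypeError. Pass lstm_cell.zero_state(batch_size, tf.float32) as the state instead. Replace

state = initial_state if initial_state else tf.zeros([batch_size, state_size], name="lstm_state")

with

state = initial_state if initial_state else lstm_cell.zero_state(batch_size, tf.float32)

Also, make sure to pass LSTMStateTuple objects to the initial_state argument.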
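To see why the initial state must match the cell's layout, here is a minimal NumPy sketch of the step BasicLSTMCell performs, unrolled over the frame axis like the question's loop. It is a re-implementation for illustration, not TensorFlow's code. The gate math follows BasicLSTMCell (one matrix multiply split into gates i, j, f, o, forget_bias defaulting to 1.0), and the two state layouts mirror lines 307 to 310 of the traceback. The names zero_state, lstm_step and run_lstm, and the random weights, are hypothetical.

```python
import numpy as np
from collections import namedtuple

# Hypothetical mirror of tf.nn.rnn_cell.LSTMStateTuple.
LSTMStateTuple = namedtuple("LSTMStateTuple", ("c", "h"))


def zero_state(batch_size, num_units, state_is_tuple=True):
    """Zero initial state in the layout the cell expects (sketch)."""
    if state_is_tuple:
        return LSTMStateTuple(np.zeros((batch_size, num_units)),
                              np.zeros((batch_size, num_units)))
    # Non-tuple layout: c and h concatenated along axis 1.
    return np.zeros((batch_size, 2 * num_units))


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x_t, state, W, b, state_is_tuple=True, forget_bias=1.0):
    """One BasicLSTMCell-style step (NumPy sketch, not TensorFlow code)."""
    if state_is_tuple:
        c, h = state                       # needs a (c, h) pair
    else:
        c, h = np.split(state, 2, axis=1)  # needs shape (batch, 2 * num_units)
    # All four gates in one multiply: input i, candidate j, forget f, output o.
    concat = np.concatenate([x_t, h], axis=1) @ W + b
    i, j, f, o = np.split(concat, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    if state_is_tuple:
        return new_h, LSTMStateTuple(new_c, new_h)
    return new_h, np.concatenate([new_c, new_h], axis=1)


def run_lstm(x, num_units, state_is_tuple=True, seed=0):
    """Unroll over axis 1 of x (batch, frames, features)."""
    batch_size, frames, features = x.shape
    rng = np.random.default_rng(seed)
    # Hypothetical random weights, shared across all time steps.
    W = rng.normal(scale=0.1, size=(features + num_units, 4 * num_units))
    b = np.zeros(4 * num_units)
    state = zero_state(batch_size, num_units, state_is_tuple)
    output = None
    for i in range(frames):
        output, state = lstm_step(x[:, i], state, W, b, state_is_tuple)
    return output, state


x = np.random.default_rng(1).normal(size=(20, 40, 16))
out, state = run_lstm(x, num_units=8)
print(out.shape, state.c.shape, state.h.shape)  # (20, 8) (20, 8) (20, 8)
```

With the question's sizes, a tuple zero state is two (20, 1024) arrays and a non-tuple one is a single (20, 2048) array. The single (20, 1024) zeros tensor in the question matches neither layout, which is why asking the cell for its own zero_state is the safer choice.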