How to reset the state of a GRU in TensorFlow after every epoch
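I am using the TensorFlow GRU cell to implement an RNN, and I am applying it to videos that are at most 5 minutes long. Since the next state is fed into the GRU automatically, how can I manually reset the state of the RNN after each epoch? In other words, I want the initial state at the start of training to always be 0. Here is a snippet of my code: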
with tf.variable_scope('GRU'):
    latent_var = tf.reshape(latent_var, shape=[batch_size, time_steps, latent_dim])
    cell = tf.nn.rnn_cell.GRUCell(cell_size)
    H, C = tf.nn.dynamic_rnn(cell, latent_var, dtype=tf.float32)
    H = tf.reshape(H, [batch_size, cell_size])
    ....
Any help is much appreciated!
Use the initial_state argument of tf.nn.dynamic_rnn:
initial_state: (optional) An initial state for the RNN. If cell.state_size is an integer, this must be a Tensor of appropriate type and shape [batch_size, cell.state_size]. If cell.state_size is a tuple, this should be a tuple of tensors having shapes [batch_size, s] for s in cell.state_size.
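The shape rule quoted above can be sketched in plain Python. The helper name initial_state_shapes is hypothetical and is not part of TensorFlow; it only restates the rule from the documentation and assumes state_size is either an integer (as it is for a GRUCell) or a flat tuple of integers.

```python
def initial_state_shapes(state_size, batch_size):
    """Return the shape(s) initial_state must have for a given state_size.

    Hypothetical helper, not a TensorFlow function. Assumes state_size is
    an int or a flat (non-nested) tuple of ints.
    """
    if isinstance(state_size, int):
        # Integer state size: a single tensor of shape [batch_size, state_size]
        return [batch_size, state_size]
    # Tuple state size: one tensor of shape [batch_size, s] per entry
    return tuple([batch_size, s] for s in state_size)


# Integer state size, as for a GRU cell with 128 units
gru_shape = initial_state_shapes(128, batch_size=32)

# Tuple state size, as for a cell whose state has two parts
tuple_shapes = initial_state_shapes((128, 64), batch_size=32)

print(gru_shape)
print(tuple_shapes)
```

For the GRU cell in the question, this means the value fed as the initial state should be a single array of shape [batch_size, cell_size].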
Adapted example from the documentation:
# create a GRUCell
cell = tf.nn.rnn_cell.GRUCell(cell_size)
# define the initial state (all zeros)
initial_state = cell.zero_state(batch_size, dtype=tf.float32)
# 'outputs' is a tensor of shape [batch_size, max_time, cell_size]
# 'state' is a tensor of shape [batch_size, cell_size]
outputs, state = tf.nn.dynamic_rnn(cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)
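Also note that although initial_state is not a placeholder, you can still feed a value to it. So if you want to preserve the state within an epoch but start from zero at the beginning of each epoch, you can do this: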
# Compute the zero state array of the right shape once
zero_state = sess.run(initial_state)
# Start with a zero vector and update it
cur_state = zero_state
for batch in get_batches():
    cur_state, _ = sess.run([state, ...], feed_dict={initial_state: cur_state, ...})
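To make the control flow over several epochs explicit, here is a minimal framework-free sketch of the same pattern. It is not TensorFlow code: run_batch is a hypothetical stand-in for the sess.run call above, the recurrence inside it is a toy update rather than a real GRU, and the batches are made-up numbers. Only the reset-and-carry logic is the point.

```python
import math


def run_batch(batch, init_state):
    """Hypothetical stand-in for sess.run([state, ...], feed_dict={initial_state: ...}).

    Runs a toy recurrence over the time steps of each sequence and returns
    the final state for every row of the batch.
    """
    final_state = []
    for sequence, h in zip(batch, init_state):
        for x in sequence:
            h = math.tanh(0.5 * h + x)  # toy update, not a real GRU
        final_state.append(h)
    return final_state


def train(batches, batch_size, num_epochs):
    """Carry the state across batches, but reset it to zero every epoch."""
    zero_state = [0.0] * batch_size
    first_states = []  # state fed to the first batch of each epoch
    last_states = []   # state left over at the end of each epoch
    for epoch in range(num_epochs):
        cur_state = list(zero_state)  # reset at the start of the epoch
        for i, batch in enumerate(batches):
            if i == 0:
                first_states.append(list(cur_state))
            # The final state of this batch becomes the initial state of the next
            cur_state = run_batch(batch, cur_state)
        last_states.append(cur_state)
    return first_states, last_states


# Two made-up batches, each with 2 sequences of 2 time steps
batches = [
    [[0.1, 0.2], [0.3, 0.4]],
    [[0.5, 0.6], [0.7, 0.8]],
]
first_states, last_states = train(batches, batch_size=2, num_epochs=3)
print(first_states)
print(last_states)
```

In the TensorFlow version, the only change needed for a per-epoch reset is the same one: assign cur_state = zero_state inside the epoch loop, before iterating over the batches.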