Backpropagating through multiple forward passes
In usual backpropagation, we forward-propagate once, compute the gradients, then apply them to update the weights. But suppose we wish to forward-propagate twice, backpropagate through both passes, and only then apply the gradients (skipping the first update).
Suppose the following:
import tensorflow as tf

x = tf.Variable([2.])
w = tf.Variable([4.])

with tf.GradientTape(persistent=True) as tape:
    w.assign(w * x)
    y = w * w  # w^2 * x

print(tape.gradient(y, x))  # >> None
From the docs, tf.Variable is a stateful object, which blocks gradients - and weights are tf.Variables.
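As a quick sanity check (a minimal sketch of my own, not from the original post): if the intermediate product is kept as a plain tensor rather than assigned back into the Variable, the same computation stays on the tape and the gradient survives.

import tensorflow as tf

x = tf.Variable([2.])
w = tf.Variable([4.])

with tf.GradientTape() as tape:
    w_x = w * x     # plain tensor, not assigned into a Variable
    y = w_x * w_x   # (w * x)^2

print(tape.gradient(y, x))  # tf.Tensor([64.], ...) == 2 * w^2 * x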
Examples would be differentiable hard attention (as opposed to RL-based), or simply passing a hidden state between layers across subsequent forward passes, as in the diagram below. Neither TF nor Keras has API-level support for stateful gradients; this includes RNNs, which only keep a stateful state tensor - gradients never flow past one batch.
How can this be accomplished?
We'll need to apply tf.while_loop with care; from help(TensorArray):

This class is meant to be used with dynamic iteration primitives such as while_loop and map_fn. It supports gradient back-propagation via special "flow" control flow dependencies.
We thus seek to write a loop such that every output we are to backpropagate through is written to a TensorArray. Code accomplishing this, along with a high-level description, is below; a validation example is at the bottom.
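Before the full code, here is a minimal standalone sketch (a toy scalar recurrence of my own, not part of the actual solution) of the mechanism: as long as every step's output is written to a TensorArray inside the tf.while_loop, the tape can backpropagate through all of the steps.

import tensorflow as tf

w = tf.Variable(0.5)
x = tf.constant([1., 2., 3.])

with tf.GradientTape() as tape:
    ta = tf.TensorArray(tf.float32, size=3)
    state = tf.constant(0.)

    def body(t, state, ta):
        state = state * w + x[t]          # toy recurrence
        return t + 1, state, ta.write(t, state)

    _, state, ta = tf.while_loop(
        cond=lambda t, *_: t < 3,
        body=body,
        loop_vars=(tf.constant(0), state, ta))
    loss = tf.reduce_sum(ta.stack())      # loss over every written step

print(tape.gradient(loss, w))  # not None; covers all three steps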
Description:
- Code borrows from K.rnn, rewritten for simplicity and relevance
- For a better understanding, I suggest inspecting K.rnn, SimpleRNNCell.call, and RNN.call
- model_rnn has a few needless checks for the sake of case 3; will link a cleaner version
- The idea is as follows: we traverse the network bottom-to-top first, then left-to-right, and write the entire forward pass to a single TensorArray under a single tf.while_loop; this ensures TF caches the tensor ops throughout for backpropagation
import tensorflow as tf
from tensorflow.python.util import nest
from tensorflow.python.ops import array_ops, tensor_array_ops
from tensorflow.python.framework import ops


def model_rnn(model, inputs, states=None, swap_batch_timestep=True):
    def step_function(inputs, states):
        out = model([inputs, *states], training=True)
        output, new_states = (out if isinstance(out, (tuple, list)) else
                              (out, states))
        return output, new_states

    def _swap_batch_timestep(input_t):
        # (samples, timesteps, channels) -> (timesteps, samples, channels)
        # iterating dim0 to feed (samples, channels) slices expected by RNN
        axes = list(range(len(input_t.shape)))
        axes[0], axes[1] = 1, 0
        return array_ops.transpose(input_t, axes)

    if swap_batch_timestep:
        inputs = nest.map_structure(_swap_batch_timestep, inputs)

    if states is None:
        states = (tf.zeros(model.inputs[0].shape, dtype='float32'),)
    initial_states = states
    input_ta, output_ta, time, time_steps_t = _process_args(model, inputs)

    def _step(time, output_ta_t, *states):
        current_input = input_ta.read(time)
        output, new_states = step_function(current_input, tuple(states))

        flat_state = nest.flatten(states)
        flat_new_state = nest.flatten(new_states)
        for state, new_state in zip(flat_state, flat_new_state):
            if isinstance(new_state, ops.Tensor):
                new_state.set_shape(state.shape)

        output_ta_t = output_ta_t.write(time, output)
        new_states = nest.pack_sequence_as(initial_states, flat_new_state)
        return (time + 1, output_ta_t) + tuple(new_states)

    final_outputs = tf.while_loop(
        body=_step,
        loop_vars=(time, output_ta) + tuple(initial_states),
        cond=lambda time, *_: tf.math.less(time, time_steps_t))

    new_states = final_outputs[2:]
    output_ta = final_outputs[1]
    outputs = output_ta.stack()
    return outputs, new_states


def _process_args(model, inputs):
    time_steps_t = tf.constant(inputs.shape[0], dtype='int32')

    # assume single-input network (excluding states)
    input_ta = tensor_array_ops.TensorArray(
        dtype=inputs.dtype,
        size=time_steps_t,
        tensor_array_name='input_ta_0').unstack(inputs)

    # assume single-input network (excluding states)
    # if having states, infer info from non-state nodes
    output_ta = tensor_array_ops.TensorArray(
        dtype=model.outputs[0].dtype,
        size=time_steps_t,
        element_shape=model.outputs[0].shape,
        tensor_array_name='output_ta_0')

    time = tf.constant(0, dtype='int32', name='time')
    return input_ta, output_ta, time, time_steps_t
An example and validation:

Case design: we feed the same input twice, which enables certain stateful vs. stateless comparisons; results also hold for differing inputs.

- Case 0: control case; the other cases must match it.
- Case 1: fail; gradients don't match, even though outputs and loss do. Backpropagation fails when feeding the halved sequences.
- Case 2: gradients match case 1's. It may look like we used only one tf.while_loop, but SimpleRNN uses one of its own over the 3 timesteps and writes to a TensorArray that gets discarded; that won't do. A workaround is to implement the SimpleRNN logic ourselves.
- Case 3: perfect match.

Note that there is no such thing as a stateful RNN cell; statefulness is implemented in the RNN base class, and we've recreated it in model_rnn. This is likewise how any other layer would have to be handled - feeding one step slice at a time per forward pass.
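To make that concrete (a minimal eager-mode sketch of my own, not one of the numbered cases): "stateful" here simply means carrying the cell's state tensor across step slices within one tape, e.g. with a SimpleRNNCell.

import tensorflow as tf
from tensorflow.keras.layers import SimpleRNNCell

cell = SimpleRNNCell(4)
x = tf.random.normal((2, 3, 4))   # (samples, timesteps, channels)
state = [tf.zeros((2, 4))]        # initial state, carried by hand

with tf.GradientTape() as tape:
    outs = []
    for t in range(3):            # one step slice per "forward pass"
        out, state = cell(x[:, t], state)
        outs.append(out)
    loss = tf.reduce_sum(tf.stack(outs))

print(tape.gradient(loss, cell.kernel))  # gradient spans all steps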
import random
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, SimpleRNN, SimpleRNNCell
from tensorflow.keras.models import Model
def reset_seeds():
    random.seed(0)
    np.random.seed(1)
    tf.compat.v1.set_random_seed(2)  # graph-level seed
    tf.random.set_seed(3)  # global seed

def print_report(case, model, outs, loss, tape, idx=1):
    print("\nCASE #%s" % case)
    print("LOSS", loss)
    print("GRADS:\n", tape.gradient(loss, model.layers[idx].weights[0]))
    print("OUTS:\n", outs)
#%%# Make data ###############################################################
reset_seeds()
x0 = y0 = tf.constant(np.random.randn(2, 3, 4))
x0_2 = y0_2 = tf.concat([x0, x0], axis=1)
x00 = y00 = tf.stack([x0, x0], axis=0)
#%%# Case 0: Complete forward pass; control case #############################
reset_seeds()
ipt = Input(batch_shape=(2, 6, 4))
out = SimpleRNN(4, return_sequences=True)(ipt)
model0 = Model(ipt, out)
model0.compile('sgd', 'mse')
#%%#############################################################
with tf.GradientTape(persistent=True) as tape:
    outs = model0(x0_2, training=True)
    loss = model0.compiled_loss(y0_2, outs)
print_report(0, model0, outs, loss, tape)
#%%# Case 1: Two passes, stateful RNN, direct feeding ########################
reset_seeds()
ipt = Input(batch_shape=(2, 3, 4))
out = SimpleRNN(4, return_sequences=True, stateful=True)(ipt)
model1 = Model(ipt, out)
model1.compile('sgd', 'mse')
#%%#############################################################
with tf.GradientTape(persistent=True) as tape:
    outs0 = model1(x0, training=True)
    tape.watch(outs0)  # cannot even diff otherwise
    outs1 = model1(x0, training=True)
    tape.watch(outs1)
    outs = tf.concat([outs0, outs1], axis=1)
    tape.watch(outs)
    loss = model1.compiled_loss(y0_2, outs)
print_report(1, model1, outs, loss, tape)
#%%# Case 2: Two passes, stateful RNN, model_rnn #############################
reset_seeds()
ipt = Input(batch_shape=(2, 3, 4))
out = SimpleRNN(4, return_sequences=True, stateful=True)(ipt)
model2 = Model(ipt, out)
model2.compile('sgd', 'mse')
#%%#############################################################
with tf.GradientTape(persistent=True) as tape:
    outs, _ = model_rnn(model2, x00, swap_batch_timestep=False)
    outs = tf.concat(list(outs), axis=1)
    loss = model2.compiled_loss(y0_2, outs)
print_report(2, model2, outs, loss, tape)
#%%# Case 3: Single pass, stateless RNN, model_rnn ###########################
reset_seeds()
ipt = Input(batch_shape=(2, 4))
sipt = Input(batch_shape=(2, 4))
out, state = SimpleRNNCell(4)(ipt, sipt)
model3 = Model([ipt, sipt], [out, state])
model3.compile('sgd', 'mse')
#%%#############################################################
with tf.GradientTape(persistent=True) as tape:
    outs, _ = model_rnn(model3, x0_2)
    outs = tf.transpose(outs, (1, 0, 2))
    loss = model3.compiled_loss(y0_2, outs)
print_report(3, model3, outs, loss, tape, idx=2)
Vertical flow: we've validated horizontal, timewise backpropagation; what about vertical?
To this end, we implement a stacked stateful RNN; results below. All outputs on my machine, here.
We have hereby validated both vertical and horizontal stateful backpropagation. This can be used to implement arbitrarily complex forward-propagation logic with correct backpropagation. Applied example.
#%%# Case 4: Complete forward pass; control case ############################
reset_seeds()
ipt = Input(batch_shape=(2, 6, 4))
x = SimpleRNN(4, return_sequences=True)(ipt)
out = SimpleRNN(4, return_sequences=True)(x)
model4 = Model(ipt, out)
model4.compile('sgd', 'mse')
#%%
with tf.GradientTape(persistent=True) as tape:
    outs = model4(x0_2, training=True)
    loss = model4.compiled_loss(y0_2, outs)
print("=" * 80)
print_report(4, model4, outs, loss, tape, idx=1)
print_report(4, model4, outs, loss, tape, idx=2)
#%%# Case 5: Two passes, stateless RNN; model_rnn ############################
reset_seeds()
ipt = Input(batch_shape=(2, 6, 4))
out = SimpleRNN(4, return_sequences=True)(ipt)
model5a = Model(ipt, out)
model5a.compile('sgd', 'mse')
ipt = Input(batch_shape=(2, 4))
sipt = Input(batch_shape=(2, 4))
out, state = SimpleRNNCell(4)(ipt, sipt)
model5b = Model([ipt, sipt], [out, state])
model5b.compile('sgd', 'mse')
#%%
with tf.GradientTape(persistent=True) as tape:
    outs = model5a(x0_2, training=True)
    outs, _ = model_rnn(model5b, outs)
    outs = tf.transpose(outs, (1, 0, 2))
    loss = model5a.compiled_loss(y0_2, outs)
print_report(5, model5a, outs, loss, tape)
print_report(5, model5b, outs, loss, tape, idx=2)
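Finally, to tie back to the original goal of applying gradients only after both passes - a sketch of my own (reusing model3, x0_2, and y0_2 from Case 3; the optimizer choice is arbitrary): compute the loss over the concatenated outputs, then apply a single update.

opt = tf.keras.optimizers.SGD(1e-2)

with tf.GradientTape() as tape:
    outs, _ = model_rnn(model3, x0_2)   # both passes' worth of steps
    outs = tf.transpose(outs, (1, 0, 2))
    loss = model3.compiled_loss(y0_2, outs)

grads = tape.gradient(loss, model3.trainable_weights)
opt.apply_gradients(zip(grads, model3.trainable_weights))  # applied only now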