梯度检查点返回值
Gradient Checkpointing returning values
我有一个检查点回调函数(即 custom_dec),return 是一个张量和一个字典。但似乎这个函数没有 return 字典(或其他数据类型),而只有张量。解决方法是什么,因为我要检查点的模块是 return 张量,加上数据类型作为字典:
def custom_dec(self, module):
def custom_forward(*inputs):
output = module(inputs[0], inputs[1],
encoder_attn_mask=inputs[2],
decoder_padding_mask=inputs[3],
layer_state=inputs[4],
causal_mask=inputs[5],
output_attentions=inputs[6],
)
# output[2] is a python dictionary
return output[0], output[2]
检查点调用如下:
x, layer_past = \
checkpoint.checkpoint(
self.custom_dec(decoder_layer),
x,
encoder_hidden_states,
encoder_padding_mask,
decoder_padding_mask,
layer_state,
decoder_causal_mask,
output_attentions,
)
错误:
TypeError: CheckpointFunctionBackward.forward: expected Variable (got
dictionary) for return value 1
讨论了类似的情况 here。
你可以做的是将字典转换成某种张量形式。我遇到了一个错误,它是由 torch.utils.checkpoint
不接受的输入列表引起的。我的解决方案是将列表中的张量作为独立张量传递,并在 custom_forward
.
中形成一个列表
我不知道你的字典的形式(例如,如果每个键总是有一个值),但你可以想出一个适用于你的字典的字典-张量互换方案。
我有一个检查点回调函数(即 custom_dec),return 是一个张量和一个字典。但似乎这个函数没有 return 字典(或其他数据类型),而只有张量。解决方法是什么,因为我要检查点的模块是 return 张量,加上数据类型作为字典:
def custom_dec(self, module):
def custom_forward(*inputs):
output = module(inputs[0], inputs[1],
encoder_attn_mask=inputs[2],
decoder_padding_mask=inputs[3],
layer_state=inputs[4],
causal_mask=inputs[5],
output_attentions=inputs[6],
)
# output[2] is a python dictionary
return output[0], output[2]
检查点调用如下:
x, layer_past = \
checkpoint.checkpoint(
self.custom_dec(decoder_layer),
x,
encoder_hidden_states,
encoder_padding_mask,
decoder_padding_mask,
layer_state,
decoder_causal_mask,
output_attentions,
)
错误:
TypeError: CheckpointFunctionBackward.forward: expected Variable (got dictionary) for return value 1
讨论了类似的情况 here。
你可以做的是将字典转换成某种张量形式。我遇到了一个错误,它是由 torch.utils.checkpoint
不接受的输入列表引起的。我的解决方案是将列表中的张量作为独立张量传递,并在 custom_forward
.
我不知道你的字典的形式(例如,如果每个键总是有一个值),但你可以想出一个适用于你的字典的字典-张量互换方案。