梯度检查点返回值

Gradient Checkpointing returning values

我有一个检查点回调函数(即 custom_dec),return 是一个张量和一个字典。但似乎这个函数没有 return 字典(或其他数据类型),而只有张量。解决方法是什么,因为我要检查点的模块是 return 张量,加上数据类型作为字典:

def custom_dec(self, module):
        def custom_forward(*inputs):
            output = module(inputs[0], inputs[1],
                            encoder_attn_mask=inputs[2],
                            decoder_padding_mask=inputs[3],
                            layer_state=inputs[4],
                            causal_mask=inputs[5],
                            output_attentions=inputs[6],
                            )
            # output[2] is a python dictionary
            return output[0], output[2]

检查点调用如下:

x, layer_past = \
                checkpoint.checkpoint(
                    self.custom_dec(decoder_layer),
                    x,
                    encoder_hidden_states,
                    encoder_padding_mask,
                    decoder_padding_mask,
                    layer_state,
                    decoder_causal_mask,
                    output_attentions,
                )

错误:

TypeError: CheckpointFunctionBackward.forward: expected Variable (got dictionary) for return value 1

讨论了类似的情况 here

你可以做的是将字典转换成某种张量形式。我遇到了一个错误,它是由 torch.utils.checkpoint 不接受的输入列表引起的。我的解决方案是将列表中的张量作为独立张量传递,并在 custom_forward.

中形成一个列表

我不知道你的字典的形式(例如,如果每个键总是有一个值),但你可以想出一个适用于你的字典的字典-张量互换方案。