PyTorch 不能 pickle lambda

PyTorch can't pickle lambda

我有一个使用自定义 LambdaLayer 的模型,如下所示:

class LambdaLayer(LightningModule):
    def __init__(self, fun):
        super(LambdaLayer, self).__init__()
        self.fun = fun

    def forward(self, x):
        return self.fun(x)


class TorchCatEmbedding(LightningModule):
    def __init__(self, start, end):
        super(TorchCatEmbedding, self).__init__()
        self.lb = LambdaLayer(lambda x: x[:, start:end])
        self.embedding = torch.nn.Embedding(50, 5)

    def forward(self, inputs):
        o = self.lb(inputs).to(torch.int32)
        o = self.embedding(o)
        return o.squeeze()

该模型在 CPU 或 1 个 GPU 上运行完美。但是,当 运行 使用 PyTorch Lightning 超过 2 个 GPU 时,会发生此错误:

AttributeError: Can't pickle local object 'TorchCatEmbedding.__init__.<locals>.<lambda>'

这里使用 lambda 函数的目的是给定一个 inputs 张量,我只想将 inputs[:, start:end] 传递给 embedding 层。

我的问题:

所以问题不在于 lambda 函数本身,而是 pickle 不适用于不仅仅是模块级函数的函数(pickle 处理函数的方式就像对某些模块级函数的引用一样)姓名)。所以,不幸的是,如果你需要捕获 startend 参数,你将无法使用闭包,你通常只需要像这样的东西:

def function_maker(start, end):
    def function(x):
        return x[:, start:end]
    return function

但是就 pickling 问题而言,这会让你回到开始的地方。

所以,试试这样的东西:

class Slicer:
    def __init__(self, start, end):
        self.start = start
        self.end = end
    def __call__(self, x):
        return x[:, self.start:self.end])

那么你可以使用:

LambdaLayer(Slicer(start, end))

我不熟悉 PyTorch,但我很惊讶它不提供使用不同序列化后端的能力。例如,pathos/dill 项目可以 pickle 任意函数,而且通常更容易使用它。但我相信以上应该可以解决问题。