PyTorch 不能 pickle lambda
PyTorch can't pickle lambda
我有一个使用自定义 LambdaLayer
的模型,如下所示:
class LambdaLayer(LightningModule):
def __init__(self, fun):
super(LambdaLayer, self).__init__()
self.fun = fun
def forward(self, x):
return self.fun(x)
class TorchCatEmbedding(LightningModule):
def __init__(self, start, end):
super(TorchCatEmbedding, self).__init__()
self.lb = LambdaLayer(lambda x: x[:, start:end])
self.embedding = torch.nn.Embedding(50, 5)
def forward(self, inputs):
o = self.lb(inputs).to(torch.int32)
o = self.embedding(o)
return o.squeeze()
该模型在 CPU 或 1 个 GPU 上运行完美。但是,当 运行 使用 PyTorch Lightning 超过 2 个 GPU 时,会发生此错误:
AttributeError: Can't pickle local object 'TorchCatEmbedding.__init__.<locals>.<lambda>'
这里使用 lambda 函数的目的是给定一个 inputs
张量,我只想将 inputs[:, start:end]
传递给 embedding
层。
我的问题:
- 在这种情况下,是否有替代方法来使用 lambda?
- 如果不是,应该怎么做才能让 lambda 函数在这种情况下工作?
所以问题不在于 lambda 函数本身,而是 pickle 不适用于不仅仅是模块级函数的函数(pickle 处理函数的方式就像对某些模块级函数的引用一样)姓名)。所以,不幸的是,如果你需要捕获 start
和 end
参数,你将无法使用闭包,你通常只需要像这样的东西:
def function_maker(start, end):
def function(x):
return x[:, start:end]
return function
但是就 pickling 问题而言,这会让你回到开始的地方。
所以,试试这样的东西:
class Slicer:
def __init__(self, start, end):
self.start = start
self.end = end
def __call__(self, x):
return x[:, self.start:self.end])
那么你可以使用:
LambdaLayer(Slicer(start, end))
我不熟悉 PyTorch,但我很惊讶它不提供使用不同序列化后端的能力。例如,pathos/dill 项目可以 pickle 任意函数,而且通常更容易使用它。但我相信以上应该可以解决问题。
我有一个使用自定义 LambdaLayer
的模型,如下所示:
class LambdaLayer(LightningModule):
def __init__(self, fun):
super(LambdaLayer, self).__init__()
self.fun = fun
def forward(self, x):
return self.fun(x)
class TorchCatEmbedding(LightningModule):
def __init__(self, start, end):
super(TorchCatEmbedding, self).__init__()
self.lb = LambdaLayer(lambda x: x[:, start:end])
self.embedding = torch.nn.Embedding(50, 5)
def forward(self, inputs):
o = self.lb(inputs).to(torch.int32)
o = self.embedding(o)
return o.squeeze()
该模型在 CPU 或 1 个 GPU 上运行完美。但是,当 运行 使用 PyTorch Lightning 超过 2 个 GPU 时,会发生此错误:
AttributeError: Can't pickle local object 'TorchCatEmbedding.__init__.<locals>.<lambda>'
这里使用 lambda 函数的目的是给定一个 inputs
张量,我只想将 inputs[:, start:end]
传递给 embedding
层。
我的问题:
- 在这种情况下,是否有替代方法来使用 lambda?
- 如果不是,应该怎么做才能让 lambda 函数在这种情况下工作?
所以问题不在于 lambda 函数本身,而是 pickle 不适用于不仅仅是模块级函数的函数(pickle 处理函数的方式就像对某些模块级函数的引用一样)姓名)。所以,不幸的是,如果你需要捕获 start
和 end
参数,你将无法使用闭包,你通常只需要像这样的东西:
def function_maker(start, end):
def function(x):
return x[:, start:end]
return function
但是就 pickling 问题而言,这会让你回到开始的地方。
所以,试试这样的东西:
class Slicer:
def __init__(self, start, end):
self.start = start
self.end = end
def __call__(self, x):
return x[:, self.start:self.end])
那么你可以使用:
LambdaLayer(Slicer(start, end))
我不熟悉 PyTorch,但我很惊讶它不提供使用不同序列化后端的能力。例如,pathos/dill 项目可以 pickle 任意函数,而且通常更容易使用它。但我相信以上应该可以解决问题。