nn.Sequential 的参数太少或太多

Either too little or too many arguments for a nn.Sequential

我是 PyTorch 的新手,所以请原谅我的愚蠢问题。

我在编码器对象的初始化中定义了一个 nn.Sequential,如下所示:

self.list_of_blocks = [EncoderBlock(n_features, n_heads, n_hidden, dropout) for _ in range(n_blocks)]
self.blocks = nn.Sequential(*self.list_of_blocks)

EncoderBlock 的 forward 是这样的

def forward(self, x, mask):

在我的编码器的 forward() 中,我尝试做:

z0 = self.blocks(z0, mask)

我希望 nn.Sequential 将这两个参数传递给各个块。

但是,我得到

TypeError: forward() takes 2 positional arguments but 3 were given

当我尝试时:

z0 = self.blocks(z0)

我得到(可以理解):

TypeError: forward() takes 2 positional arguments but only 1 was given

当我不使用 nn.Sequential 而只是一个接一个地执行 EncoderBlock 时,它有效:

for i in range(self.n_blocks):
     z0 = self.list_of_blocks[i](z0, mask)

问题:我做错了什么,在这种情况下如何正确使用nn.Sequential?

顺序通常不适用于多个输入和输出。

这是一个经常讨论的话题,见PyTorch forum and GitHub issues #1908 or #9979

您可以定义自己的顺序版本。假设所有编码器块的掩码都相同(例如,在 Transformer 网络中),您可以:

class MaskedSequential(nn.Sequential):
    def forward(self, x, mask):
        for module in self._modules.values():
            x = module(x, mask)
        return inputs

或者,如果您的 EncoderBlocks return 元组,您可以使用 GitHub issues:

之一中建议的更通用的解决方案
class MySequential(nn.Sequential):
    def forward(self, *inputs):
        for module in self._modules.values():
            if type(inputs) == tuple:
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs