nn.Sequential 的参数太少或太多
Either too little or too many arguments for a nn.Sequential
我是 PyTorch 的新手,所以请原谅我的愚蠢问题。
我在编码器对象的初始化中定义了一个 nn.Sequential,如下所示:
self.list_of_blocks = [EncoderBlock(n_features, n_heads, n_hidden, dropout) for _ in range(n_blocks)]
self.blocks = nn.Sequential(*self.list_of_blocks)
EncoderBlock 的 forward 是这样的
def forward(self, x, mask):
在我的编码器的 forward() 中,我尝试做:
z0 = self.blocks(z0, mask)
我希望 nn.Sequential 将这两个参数传递给各个块。
但是,我得到
TypeError: forward() takes 2 positional arguments but 3 were given
当我尝试时:
z0 = self.blocks(z0)
我得到(可以理解):
TypeError: forward() takes 2 positional arguments but only 1 was given
当我不使用 nn.Sequential 而只是一个接一个地执行 EncoderBlock 时,它有效:
for i in range(self.n_blocks):
z0 = self.list_of_blocks[i](z0, mask)
问题:我做错了什么,在这种情况下如何正确使用nn.Sequential?
顺序通常不适用于多个输入和输出。
这是一个经常讨论的话题,见PyTorch forum and GitHub issues #1908 or #9979。
您可以定义自己的顺序版本。假设所有编码器块的掩码都相同(例如,在 Transformer 网络中),您可以:
class MaskedSequential(nn.Sequential):
def forward(self, x, mask):
for module in self._modules.values():
x = module(x, mask)
return inputs
或者,如果您的 EncoderBlock
s return 元组,您可以使用 GitHub issues:
之一中建议的更通用的解决方案
class MySequential(nn.Sequential):
def forward(self, *inputs):
for module in self._modules.values():
if type(inputs) == tuple:
inputs = module(*inputs)
else:
inputs = module(inputs)
return inputs
我是 PyTorch 的新手,所以请原谅我的愚蠢问题。
我在编码器对象的初始化中定义了一个 nn.Sequential,如下所示:
self.list_of_blocks = [EncoderBlock(n_features, n_heads, n_hidden, dropout) for _ in range(n_blocks)]
self.blocks = nn.Sequential(*self.list_of_blocks)
EncoderBlock 的 forward 是这样的
def forward(self, x, mask):
在我的编码器的 forward() 中,我尝试做:
z0 = self.blocks(z0, mask)
我希望 nn.Sequential 将这两个参数传递给各个块。
但是,我得到
TypeError: forward() takes 2 positional arguments but 3 were given
当我尝试时:
z0 = self.blocks(z0)
我得到(可以理解):
TypeError: forward() takes 2 positional arguments but only 1 was given
当我不使用 nn.Sequential 而只是一个接一个地执行 EncoderBlock 时,它有效:
for i in range(self.n_blocks):
z0 = self.list_of_blocks[i](z0, mask)
问题:我做错了什么,在这种情况下如何正确使用nn.Sequential?
顺序通常不适用于多个输入和输出。
这是一个经常讨论的话题,见PyTorch forum and GitHub issues #1908 or #9979。
您可以定义自己的顺序版本。假设所有编码器块的掩码都相同(例如,在 Transformer 网络中),您可以:
class MaskedSequential(nn.Sequential):
def forward(self, x, mask):
for module in self._modules.values():
x = module(x, mask)
return inputs
或者,如果您的 EncoderBlock
s return 元组,您可以使用 GitHub issues:
class MySequential(nn.Sequential):
def forward(self, *inputs):
for module in self._modules.values():
if type(inputs) == tuple:
inputs = module(*inputs)
else:
inputs = module(inputs)
return inputs