pytorch在顺序模型中跳过连接
pytorch skip connection in a sequential model
我正在努力解决顺序模型中的跳过连接问题。使用功能 API 我会做一些简单的事情(快速示例,可能不是 100% 语法正确但应该明白这个想法):
x1 = self.conv1(inp)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.deconv4(x)
x = self.deconv3(x)
x = self.deconv2(x)
x = torch.cat((x, x1), 1))
x = self.deconv1(x)
我现在正在使用顺序模型并尝试做类似的事情,创建一个跳过连接,将第一个 conv 层的激活一直带到最后一个 convTranspose。我看了一下实现的 U-net 架构 here 有点混乱,它做了这样的事情:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
if use_dropout:
model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
model = down + [submodule] + up
这不就是在顺序模型中很好地、顺序地添加层吗? down
conv 后面跟着 submodule
(递归地添加内层),然后连接到 up
,这是 upconv 层。我可能遗漏了一些关于 Sequential
API 工作原理的重要信息,但是从 U-NET 截取的代码实际上是如何实现跳过的?
您的观察是正确的,但您可能错过了 UnetSkipConnectionBlock.forward()
的定义(UnetSkipConnectionBlock
是定义您共享的 U-Net 块的 Module
),这可能会澄清这一点实施:
(来自 pytorch-CycleGAN-and-pix2pix/models/networks.py#L259
)
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
# ...
def forward(self, x):
if self.outermost:
return self.model(x)
else:
return torch.cat([x, self.model(x)], 1)
最后一行是关键(适用于所有内部块)。跳过层只需将输入 x
和(递归)块输出 self.model(x)
与您提到的操作列表 self.model
连接起来即可完成 - 因此与 Functional
你写的代码。
我正在努力解决顺序模型中的跳过连接问题。使用功能 API 我会做一些简单的事情(快速示例,可能不是 100% 语法正确但应该明白这个想法):
x1 = self.conv1(inp)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.deconv4(x)
x = self.deconv3(x)
x = self.deconv2(x)
x = torch.cat((x, x1), 1))
x = self.deconv1(x)
我现在正在使用顺序模型并尝试做类似的事情,创建一个跳过连接,将第一个 conv 层的激活一直带到最后一个 convTranspose。我看了一下实现的 U-net 架构 here 有点混乱,它做了这样的事情:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
if use_dropout:
model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
model = down + [submodule] + up
这不就是在顺序模型中很好地、顺序地添加层吗? down
conv 后面跟着 submodule
(递归地添加内层),然后连接到 up
,这是 upconv 层。我可能遗漏了一些关于 Sequential
API 工作原理的重要信息,但是从 U-NET 截取的代码实际上是如何实现跳过的?
您的观察是正确的,但您可能错过了 UnetSkipConnectionBlock.forward()
的定义(UnetSkipConnectionBlock
是定义您共享的 U-Net 块的 Module
),这可能会澄清这一点实施:
(来自 pytorch-CycleGAN-and-pix2pix/models/networks.py#L259
)
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
# ...
def forward(self, x):
if self.outermost:
return self.model(x)
else:
return torch.cat([x, self.model(x)], 1)
最后一行是关键(适用于所有内部块)。跳过层只需将输入 x
和(递归)块输出 self.model(x)
与您提到的操作列表 self.model
连接起来即可完成 - 因此与 Functional
你写的代码。