torch.nn.Sequential 的设计块在提供输入时存在问题

torch.nn.Sequential of designed blocks problem in giving inputs

我设计了一个class,它是一个网络块,它的前向有三个输入:x、logdet、reverse,并有两个输出。 例如,当我调用此 class 并使用它时一切正常,例如:

x = torch.Tensor(np.random.rand(2, 48, 8, 8))
net = Block(inp = 48, oup = 48, mid_channels=48, ksize=3, stride=1, group = 3)
a, _ = net(x, reverse = False)

但是当我想按Sequential来使用的时候(因为我需要一个接一个的多块),问题是这样的:

x = torch.Tensor(np.random.rand(2, 48, 8, 8))
conv1_network = nn.Sequential(
    Block(inp = 48, oup = 48, mid_channels=48, ksize=3, stride=1, group = 3)
        )
conv1_network(x, reverse = False)

我的错误是: TypeError: forward() got an unexpected keyword argument 'reverse' 这是不正常的,因为正如我们在第一部分中看到的那样,我在 Block 中的正向输入中有反向。 我期待找到一种方法将一些块相互连接,例如这是一个块

class Block(nn.Module):
    def __init__(self, num_channels):
        super(InvConv, self).__init__()
        self.num_channels = num_channels

        # Initialize with a random orthogonal matrix
        w_init = np.random.randn(num_channels, num_channels)
        w_init = np.linalg.qr(w_init)[0].astype(np.float32)
        self.weight = nn.Parameter(torch.from_numpy(w_init))

    def forward(self, x, logdet, reverse=False):
        ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)

        if reverse:
            weight = torch.inverse(self.weight.double()).float()
            logdet = logdet - ldj
        else:
            weight = self.weight
            logdet = logdet + ldj

        weight = weight.view(self.num_channels, self.num_channels, 1, 1)
        z = F.conv2d(x, weight)

        return z, logdet

而我的目的是在一个for中将多个Block以Sequential的方式相互连接(因为我不能在我的工作中使用相同的Block,所以我需要不同的卷积来制作深度网络)

features = []
for i in range(10):
   self.features.append(Block(num_channels = 48))

然后我想像这样使用它们

self.features(x, logdet = 0, reverse = False)

您表示您的 Block nn.Module 有一个 reverse 选项。但是 nn.Sequential 没有,所以 conv1_network(x, reverse=False) 无效,因为 conv1_network 不是 Block.

默认情况下,您无法将 kwargs 传递给 nn.Sequential 内的图层。但是,您可以继承 nn.Sequential 并自己完成。类似于:

class BlockSequence(nn.Sequential):
    def forward(self, input, **kwargs):
        for module in self:
            options = kwargs if isinstance(module, Block) else {}
            input = module(input, **options)
        return input

这样,您可以创建一个包含 Block 的序列(以及可选的非 Block 模块):

>>> blocks = []
>>> for i in range(10):
...     self.blocks.append(Block(num_channels=48))

>>> blocks = BlockSequence(*blocks)

然后您将能够使用 reverse 关键字参数调用 blocks,调用时会转发给每个潜在的 Block 子模块:

>>> blocks(x, logdet=0, reverse=False)