Is it possible to add a trainable filter after an autoencoder?

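I'm building a denoiser with an autoencoder. The idea is that after the autoencoder, but before computing my loss, I apply an empirical Wiener filter to a texture map of the image and add the result back to the autoencoder output (adding back the "lost detail"). I wrote this filter in PyTorch.

My first attempt was to add the filter to the end of my autoencoder's forward function. I can train this network, and it backpropagates through my filter during training. However, if I print my network, the filter is not listed, and torchsummary does not include it when counting parameters.

This makes me think that I am only training the autoencoder, and that my filter filters the same way every time instead of learning.

Is what I am trying to do possible?

Below is my autoencoder: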

import torch
import torch.nn as nn

# wiener_3d() is my own filter function; it is not shown here.


class AutoEncoder(nn.Module):
    """Autoencoder simple implementation """
    def __init__(self):
        super(AutoEncoder, self).__init__()
        # Encoder
        # conv layer
        self.block1 = nn.Sequential(
            nn.Conv2d(1, 48, 3, padding=1),
            nn.Conv2d(48, 48, 3, padding=1),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.1)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(48, 48, 3, padding=1),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.1)
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(48, 48, 3, padding=1),
            nn.ConvTranspose2d(48, 48, 2, 2, output_padding=1),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.1)
        )
        self.block4 = nn.Sequential(
            nn.Conv2d(96, 96, 3, padding=1),
            nn.Conv2d(96, 96, 3, padding=1),
            nn.ConvTranspose2d(96, 96, 2, 2),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.1)
        )
        self.block5 = nn.Sequential(
            nn.Conv2d(144, 96, 3, padding=1),
            nn.Conv2d(96, 96, 3, padding=1),
            nn.ConvTranspose2d(96, 96, 2, 2),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.1)
        )
        self.block6 = nn.Sequential(
            nn.Conv2d(97, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 1, 3, padding=1),
            nn.LeakyReLU(0.1)
        )

        # self.blockNorm = nn.Sequential(
        #     nn.BatchNorm2d(1),
        #     nn.LeakyReLU(0.1)
        # )

    def forward(self, x):
        # torch.autograd.set_detect_anomaly(True)
        # print("input: ", x.shape)
        pool1 = self.block1(x)
        # print("pool1: ", pool1.shape)
        pool2 = self.block2(pool1)
        # print("pool2: ", pool2.shape)
        pool3 = self.block2(pool2)
        # print("pool3: ", pool3.shape)
        pool4 = self.block2(pool3)
        # print("pool4: ", pool4.shape)
        pool5 = self.block2(pool4)
        # print("pool5: ", pool5.shape)
        upsample5 = self.block3(pool5)
        # print("upsample5: ", upsample5.shape)
        concat5 = torch.cat((upsample5, pool4), 1)
        # print("concat5: ", concat5.shape)
        upsample4 = self.block4(concat5)
        # print("upsample4: ", upsample4.shape)
        concat4 = torch.cat((upsample4, pool3), 1)
        # print("concat4: ", concat4.shape)
        upsample3 = self.block5(concat4)
        # print("upsample3: ", upsample3.shape)
        concat3 = torch.cat((upsample3, pool2), 1)
        # print("concat3: ", concat3.shape)
        upsample2 = self.block5(concat3)
        # print("upsample2: ", upsample2.shape)
        concat2 = torch.cat((upsample2, pool1), 1)
        # print("concat2: ", concat2.shape)
        upsample1 = self.block5(concat2)
        # print("upsample1: ", upsample1.shape)
        concat1 = torch.cat((upsample1, x), 1)
        # print("concat1: ", concat1.shape)
        output = self.block6(concat1)

        t_map = x - output

        for i in range(4):
            tensor = t_map[i, :, :, :]                 # Take each item in batch separately. Could account for this in Wiener instead

            tensor = torch.squeeze(tensor)              # Squeeze for Wiener input format

            tensor = wiener_3d(tensor, 0.05, 10)        # Apply Wiener with specified std and block size
            tensor = torch.unsqueeze(tensor, 0)         # unsqueeze to put back into block
            t_map[i, :, :, :] = tensor                  # put back into block

        filtered_output = output + t_map
        return filtered_output

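The for loop at the end applies the filter to each image in the batch. I realize this is not parallelizable, so if anyone has ideas on that, I would appreciate it. I can post the wiener_3d() filter function if that helps; I just wanted to keep the post short.

I tried to define a custom layer class with the filter inside it, but I got lost very quickly.

Any help would be greatly appreciated!

If all you want is to turn your Wiener filter into a module, the following would do: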

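import torch as T


class WienerFilter(T.nn.Module):
    def __init__(self, param_a=0.05, param_b=10):
        super(WienerFilter, self).__init__()
        # Registered as a trainable parameter; it can be accessed
        # like any other member via self.param_a
        self.register_parameter("param_a", T.nn.Parameter(T.tensor(param_a)))
        # Plain attribute (block size); it is not trained
        self.param_b = param_b

    def forward(self, input):
        filtered = []
        for i in range(input.shape[0]):      # loop over the actual batch size instead of assuming 4
            tensor = T.squeeze(input[i])     # squeeze for the Wiener input format
            tensor = wiener_3d(tensor, self.param_a, self.param_b)
            filtered.append(T.unsqueeze(tensor, 0))
        return T.stack(filtered)             # new tensor; input is not modified in place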

You can apply it by adding the line

self.wiener_filter = WienerFilter()

to the __init__ function of your autoencoder.

In forward, you then use it by replacing the for loop with

filtered_output = output + self.wiener_filter(t_map)

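Torch knows that wiener_filter is a member module, so it lists the module when you print the modules of your AutoEncoder, and param_a shows up among the network's parameters. Note that param_a only receives a gradient if wiener_3d uses it through differentiable tensor operations.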
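To check that the filter is really registered and trained, something along these lines may help. It is only a minimal sketch: ScaleFilter and TinyNet are hypothetical stand-ins for WienerFilter and AutoEncoder (wiener_3d was never posted), and the filter body is a placeholder multiplication, not a Wiener filter.

```python
import torch
import torch.nn as nn


class ScaleFilter(nn.Module):
    """Hypothetical stand-in for WienerFilter with one learnable scalar."""

    def __init__(self, param_a=0.05):
        super().__init__()
        self.param_a = nn.Parameter(torch.tensor(param_a))

    def forward(self, x):
        # Placeholder operation, NOT a Wiener filter: it only needs to
        # depend on param_a so that a gradient can reach it.
        return x * self.param_a


class TinyNet(nn.Module):
    """Hypothetical stand-in for the AutoEncoder."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)
        self.wiener_filter = ScaleFilter()  # registered as a submodule

    def forward(self, x):
        output = self.conv(x)
        t_map = x - output
        return output + self.wiener_filter(t_map)


model = TinyNet()
print(model)  # the printout includes a "(wiener_filter)" entry

# Names of all registered parameters, including the one inside the filter.
param_names = [name for name, _ in model.named_parameters()]
print(param_names)

# One backward pass to confirm that a gradient reaches the filter parameter.
loss = model(torch.randn(4, 1, 16, 16)).pow(2).mean()
loss.backward()
filter_grad = model.wiener_filter.param_a.grad
print(filter_grad)
```

An optimizer built from model.parameters() picks up wiener_filter.param_a as well, so it is updated together with the autoencoder weights.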

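If you want to parallelize your Wiener filter, you need to express it in PyTorch terms, that is, with PyTorch's operations on tensors. Those operations are implemented in a parallel fashion.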
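As an illustration of what "in PyTorch terms" could look like, here is a rough sketch of a Wiener filter that handles the whole batch without a Python loop. It is an assumption-heavy example: BatchedLocalWiener is a hypothetical name, and it implements a generic local-statistics (adaptive) Wiener filter built on average pooling, which is probably not the same algorithm as your wiener_3d. The window is kept odd so the output has the same size as the input.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BatchedLocalWiener(nn.Module):
    """Local-statistics Wiener filter applied to a whole batch at once.

    Assumption: this is a generic adaptive Wiener filter, not the
    wiener_3d() function from the question, which was never posted.
    """

    def __init__(self, noise_std=0.05, window=9):
        super().__init__()
        if window % 2 == 0:
            raise ValueError("window must be odd so the output keeps its size")
        self.noise_std = nn.Parameter(torch.tensor(noise_std))  # learnable
        self.window = window                                    # fixed integer

    def forward(self, x):
        # x has shape (batch, channels, height, width).
        pad = self.window // 2
        # Local mean and local second moment via average pooling; every
        # image in the batch goes through the same tensor operation.
        # Borders are zero-padded, so statistics near the edges are biased.
        mean = F.avg_pool2d(x, self.window, stride=1, padding=pad)
        mean_sq = F.avg_pool2d(x * x, self.window, stride=1, padding=pad)
        var = (mean_sq - mean * mean).clamp(min=0.0)
        noise_var = self.noise_std ** 2
        # Wiener gain: close to 1 where the local variance is well above
        # the noise variance, 0 where it is below it.
        gain = (var - noise_var).clamp(min=0.0) / (var + 1e-8)
        return mean + gain * (x - mean)


wiener = BatchedLocalWiener()
t_map = torch.randn(4, 1, 32, 32)   # dummy texture map batch
filtered = wiener(t_map)            # no per-image loop, no in-place writes
filtered.mean().backward()          # gradient flows to noise_std
```

Because the filter returns a new tensor and works for any batch size, it could replace the per-image loop directly, e.g. filtered_output = output + self.wiener_filter(t_map).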