如何在 PyTorch 的复杂(嵌套)模块中高效地初始化各层的权重(并对其做合理性检查)?

How to efficiently initialise (and sanity-check) the weights of layers within complex (nested) modules in PyTorch?

寻找一种高效的方法,来访问嵌套模块中的各层并设置其权重

我正在复现 DCGAN 论文,代码已经能按预期运行。我注意到论文中作者写道:

All weights were initialized from a zero-centered Normal distribution with standard deviation 0.02

这可以通过 torch.nn.init.normal_(nn.Conv2d(1,1,1, 1,1 ).weight.data, 0.0, 0.02) 来实现,但我的模型结构比较复杂,用到了 ModuleList 等容器。最高效的做法是什么?

我的实现比较复杂,请看下面的代码:

'''
Implement the Deep Convolution Gan AKA DCGAN in Pytorch: Paper at https://arxiv.org/pdf/1511.06434v2.pdf
'''
import torch
import torch.nn as nn


class GeneratorBlock(nn.Module):
    '''
    One up-sampling stage of the DCGAN generator:
    ConvTranspose2d -> (optional) BatchNorm2d -> ReLU.

    The paper uses kernel_size=4, stride=2, padding=1, which doubles the spatial
    size. When batch norm is enabled the transposed convolution is built without
    a bias, because the batch-norm shift takes over that role.
    '''
    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, use_batchnorm: bool = True):
        super().__init__()
        self.use_batchnorm = use_batchnorm
        conv_has_bias = not use_batchnorm
        self.transpose_conv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=conv_has_bias,
        )
        # Stored as None when disabled so the attribute always exists.
        self.batch_norm = nn.BatchNorm2d(out_channels) if use_batchnorm else None
        self.activation = nn.ReLU()  # the paper uses ReLU throughout the generator

    def forward(self, x):
        out = self.transpose_conv(x)
        if self.use_batchnorm:
            out = self.batch_norm(out)
        return self.activation(out)


class Generator(nn.Module):
    '''
    DCGAN generator: turns a noise tensor of shape [Batch, input_features, 1, 1]
    into an image through a stack of transposed convolutions followed by Tanh.
    '''
    def __init__(self, input_features=100, base_feature=128, final_channels: int = 1):
        '''
        The layers live in an nn.Sequential (the Discriminator shows the equivalent
        nn.ModuleList approach). Hidden channel widths are 8x, 4x, 2x and 1x
        base_feature; with the default 128 that is 1024 -> 512 -> 256 -> 128.

        args:
            input_features: number of channels of the input noise tensor
            base_feature: base channel width from which every hidden width is derived
            final_channels: channels of the generated image, i.e. what the
                            Discriminator receives as input
        '''
        super().__init__()

        widths = [base_feature * factor for factor in (8, 4, 2, 1)]

        # First block: stride 1 / no padding, so a 1x1 noise map becomes 4x4.
        stages = [GeneratorBlock(in_channels=input_features, out_channels=widths[0], stride=1, padding=0)]
        # Remaining blocks use the default stride 2 and double the spatial size.
        stages.extend(
            GeneratorBlock(in_channels=c_in, out_channels=c_out)
            for c_in, c_out in zip(widths, widths[1:])
        )
        # Output layer: a plain transposed conv (no BatchNorm / ReLU); Tanh is applied in forward().
        stages.append(nn.ConvTranspose2d(widths[-1], final_channels, kernel_size=4, stride=2, padding=1))

        self.blocks = nn.Sequential(*stages)
        self.activation = nn.Tanh()  # squashes the outputs into [-1, 1]

    def forward(self, x):
        '''
        Generate images from the noise tensor x.
        '''
        features = self.blocks(x)
        return self.activation(features)
    

class DiscriminatorBlock(nn.Module):
    '''
    One down-sampling stage of the DCGAN discriminator:
    Conv2d -> (optional) BatchNorm2d -> LeakyReLU(0.2).

    The paper uses kernel_size=4, stride=2, padding=1, which halves even spatial
    sizes. When batch norm is enabled the convolution is built without a bias,
    because the batch-norm shift takes over that role.
    '''
    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, use_batchnorm: bool = True):
        super().__init__()
        self.use_batchnorm = use_batchnorm
        conv_has_bias = not use_batchnorm
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=conv_has_bias,
        )
        # Stored as None when disabled so the attribute always exists.
        self.batch_norm = nn.BatchNorm2d(out_channels) if use_batchnorm else None
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        out = self.conv(x)
        if self.use_batchnorm:
            out = self.batch_norm(out)
        return self.activation(out)
    

class Discriminator(nn.Module):
    '''
    CNN that scores how "real" an image looks.

    With the defaults the channels change as 1 -> 64 -> 128 -> 256 -> 512 -> 1,
    and a [Batch, 1, 64, 64] input is reduced to one sigmoid score per sample,
    shaped [Batch, 1, 1, 1].
    '''
    def __init__(self, input_features=1, output_features=1, middle_features=(64, 128, 256)):
        '''
        args:
            input_features: channels of the incoming image
            output_features: channels of the final score map
            middle_features: output widths of the conv blocks, in order. One extra
                             block that doubles the last width is appended
                             (256 -> 512 with the defaults). Any sequence of ints is
                             accepted; a tuple default avoids a shared mutable default.
        '''
        super().__init__()
        # Full channel chain, e.g. [64, 128, 256, 512] for the defaults.
        channels = list(middle_features)
        channels.append(channels[-1] * 2)

        self.layers = nn.ModuleList()

        # In the paper, the first layer does not use BatchNorm.
        self.layers.append(DiscriminatorBlock(input_features, channels[0], use_batchnorm=False))

        # Chain consecutive widths so each block's input matches the previous
        # block's output, even when the widths are not successive doublings.
        for c_in, c_out in zip(channels, channels[1:]):
            self.layers.append(DiscriminatorBlock(c_in, c_out))

        # k=4, s=2, p=0 collapses the final 4x4 map (for 64x64 inputs) to 1x1.
        self.final_conv = nn.Conv2d(in_channels=channels[-1], out_channels=output_features, kernel_size=4, stride=2, padding=0)
        self.sigmoid_layer = nn.Sigmoid()  # maps the score into (0, 1): how "real" the input looks

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.sigmoid_layer(self.final_conv(x))


def test_DCGAN_code():
    '''
    Smoke test: push a batch of 10 noise vectors through a default Generator and
    Discriminator, print the shapes, and assert they match the DCGAN layout.
    '''
    noise = torch.rand(10,100,1,1)
    # Only a forward pass is needed, so skip building the autograd graph.
    with torch.no_grad():
        image = Generator()(noise)
        result = Discriminator()(image)
    print('Model Built Successfully!!! Generating 10 random samples and their end results')
    print(f"'Z' random Noise shape: {noise.shape} || Generator output shape: {image.shape} || Discriminator shape: {result.shape}")
    # Fail loudly instead of only printing, so shape regressions are caught.
    assert image.shape == (10, 1, 64, 64), image.shape
    assert result.shape == (10, 1, 1, 1), result.shape
    assert bool(((result >= 0) & (result <= 1)).all()), 'Discriminator scores must lie in [0, 1]'

您只需在 __init__ 方法的末尾遍历所有子模块即可:

class Generator(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # ... build all layers here (Sequential, ModuleList, custom blocks, ...) ...

        # Initialise weights at the very bottom of __init__, once every
        # sub-module has been created.
        self._init_weights()

    def _init_weights(self):
        '''
        Draw every Conv2d / ConvTranspose2d weight from N(0, 0.02), as in the DCGAN paper.
        '''
        # self.modules() walks the module tree recursively, so layers nested in
        # Sequential / ModuleList / custom blocks are visited too.
        for sm in self.modules():
            # nn.ConvTranspose2d is NOT a subclass of nn.Conv2d, so it must be
            # listed explicitly, otherwise transposed convolutions keep
            # PyTorch's default init.
            if isinstance(sm, (nn.Conv2d, nn.ConvTranspose2d)):
                # nn.init functions already run under no_grad; `.data` is not needed.
                torch.nn.init.normal_(sm.weight, 0.0, 0.02)

完成。

我找到了一些答案,只是想确认下面这种做法是否正确:

def initialise_weights(m):
    '''
    DCGAN initialisation for a single module. Use it as
    ``model.apply(initialise_weights)``, which calls it on every sub-module.

    - Conv2d / ConvTranspose2d: weight ~ N(0, 0.02).
    - BatchNorm2d (affine): weight (the scale, gamma) ~ N(1, 0.02), bias = 0.
      Centring gamma at 0 would shrink every normalised activation to about
      0.02x its scale with a random sign. The widely used DCGAN reference
      implementations (e.g. the PyTorch DCGAN tutorial) centre it at 1.
    Other module types, and conv biases, are left untouched.
    '''
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.affine:
        # affine=False BatchNorm has weight/bias set to None; nothing to initialise.
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)

def check_sanity(m):
    '''
    Print the mean and standard deviation of the weight of a Conv2d,
    ConvTranspose2d or BatchNorm2d module; any other module is ignored.
    Intended for ``model.apply(check_sanity)``.
    '''
    watched = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
    if not isinstance(m, watched):
        return
    w = m.weight.data
    print(w.mean(), w.std())


# Build the generator, apply the init to every sub-module, then print the
# weight statistics. nn.Module.apply returns the module itself, so the calls chain.
gen = Generator().apply(initialise_weights).apply(check_sanity)

被采纳的答案是最佳方案(另一种做法是修改 _ConvNd 类的源码,即替换其中的 init.kaiming_uniform_(self.weight, a=math.sqrt(5)))。总之,最佳实践是另外定义一个名为 reset_parameters() 的方法,在 __init__(self, *args) 的末尾调用它,并在该方法中修改参数:

class Generator(nn.Module):

    def __init__(self, *args) -> None:
        # nn.Module.__init__ must run first: reset_parameters() walks
        # self.modules(), which relies on the registry it creates.
        super().__init__()
        ...  # build the sub-modules (layers) here
        self.reset_parameters()

    def reset_parameters(self) -> None:
        '''
        Re-initialise every Conv2d / ConvTranspose2d weight from N(0, 0.02), as
        prescribed by the DCGAN paper. Other modules (e.g. BatchNorm) keep
        whatever initialisation they already have.

        Kept as a separate method so the initialisation can be re-applied later,
        e.g. to restart training from scratch.
        '''
        for sm in self.modules():
            # nn.ConvTranspose2d does not inherit from nn.Conv2d, so list both.
            if isinstance(sm, (nn.Conv2d, nn.ConvTranspose2d)):
                # nn.init runs under no_grad, so the Parameter can be passed directly.
                torch.nn.init.normal_(
                    sm.weight,
                    mean=0.0,
                    std=0.02,
                )