在 PyTorch 的转换层中修剪过滤器后如何更新预训练模型?

How to update a pretrained model after Pruning of filters in its conv layer in PyTorch?

我有一个从头开始定义的预训练模型 LeNet5。我正在对下面显示的模型中存在的卷积层中的过滤器执行 p运行ing。

class LeNet5(nn.Module):

    def __init__(self, n_classes):
        super(LeNet5, self).__init__()
        self.feature_extractor = nn.Sequential(            
            nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(in_features=800, out_features=500),
            nn.ReLU(),
            nn.Linear(in_features=500, out_features=10), # 10 - possible classes
        )
    
    def forward(self, x):
        #x = x.view(x.size(0), -1) 
        x = self.feature_extractor(x)
        x = torch.flatten(x, 1)
        logits = self.classifier(x)
        probs = F.softmax(logits, dim=1)
        return logits, probs

我已经成功地从第 1 层的 20 个过滤器中移除了 2 个(现在 conv2d layer1 中的 18 个过滤器)和第 2 层中的 50 个过滤器中的 5 个过滤器(现在 conv2d layer3 中的 45 个过滤器)。所以,现在我需要使用如下所做的更改来更新模型 -

但是,我无法 运行 模型,因为它给出了尺寸错误。

RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x720 and 800x500)

如何更新号码。使用 Pytorch 执行 p运行ing 的模型中存在的过滤器层数?有没有我可以使用的图书馆?

假设您不希望模型在运行时自动更改结构,您可以通过简单地更改构造函数的输入参数来轻松更新模型的结构。例如:

nn.Conv2d(in_channels = 1, out_channels = 18, kernel_size = 5, stride = 1),
nn.Conv2d(in_channels = 18, out_channels = 45, kernel_size = 5, stride = 1),

等等。

如果每次更改模型结构时都从头开始重新训练,这就是您需要做的全部。但是,如果您想在更改模型时保留部分已学习的参数,则需要 select 这些相关值并将它们重新分配给模型参数。例如,考虑与第一个卷积层相关的参数,1 个输入,20 个输出,内核大小为 5。该层的权重和偏差大小为 [1,20,5,5][1,20]。您需要修改这些参数,使它们的大小为 [1,18,5,5][1,18]。因此,您需要针对要维护的特定 kernels/filters 以及要修剪的内核的索引。这样做的代码语法大致是:

params = net.state_dict()
params["feature_extractor"]["conv1.weight"] = params["feature_extractor"]["conv1.weight"][:,:18,:,:]
params["feature_extractor"]["conv1.bias"] = params["feature_extractor"]["conv1.bias"][:,:18]
# and so on for the other layers

net.load_state_dict(params)

在这里,我只是删除了第一个卷积层的最后两个 kernels/bias 值。 (请注意,实际的字典键名称可能略有不同;我没有编写代码来检查,因为如上面的评论所示,您包含了代码图片而不是真实的、可复制的代码,因此请尝试执行后者在未来。)