Can’t train a multilabel classifier on a small dataset

I have a dataset with 3 classes: animals, landscapes, and buildings. Each class has only 100 images, and I'm trying to train a classifier using ResNet34 + fastai, but I'm running into a few problems.

The first is that I don't think my model is training properly. When I use lr_finder, my validation loss is na, unless that's expected:

Then I train for 10 epochs. It looks like it's either learning really well or overfitting badly.

Then I unfreeze to find another learning rate, and my validation loss is still na:

Now when I train for 10 epochs, the graph starts oscillating, which tells me it can't learn any more. The accuracy looks good (unless it's overfitting), so I decided to use my validation set to see how my model actually performs:

I have 300 images in total: 240 in the train folder and 60 in the test folder. My test images are labeled because I want to run an accuracy score on them. Here's what I ran:

    path = '/content/dataset/'

    data_test = (ImageList.from_folder(path)
                 .split_by_folder(train='train', valid='test')
                 .label_from_re(file_parse)
                 .transform(size=512)
                 .databunch()
                 .normalize(imagenet_stats))

    learn = cnn_learner(data, models.resnet50, metrics=[accuracy, top_1],callback_fns=ShowGraph)

    learn.load('stage-2')

And here's the error I get:

    RuntimeError                              Traceback (most recent call last)
    <ipython-input-28-f1232eb8cf47> in <module>()
      4 
      5 learn = cnn_learner(data, models.resnet50, metrics=[accuracy, top_1],callback_fns=ShowGraph)
    ----> 6 learn.load('stage-2')

    1 frames
    /usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    828         if len(error_msgs) > 0:
    829             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
    --> 830                                self.__class__.__name__, "\n\t".join(error_msgs)))
    831         return _IncompatibleKeys(missing_keys, unexpected_keys)
    832 

    RuntimeError: Error(s) in loading state_dict for Sequential:
        Missing key(s) in state_dict: "0.4.0.conv3.weight", "0.4.0.bn3.weight", "0.4.0.bn3.bias", "0.4.0.bn3.running_mean", "0.4.0.bn3.running_var", "0.4.0.downsample.0.weight", "0.4.0.downsample.1.weight", "0.4.0.downsample.1.bias", "0.4.0.downsample.1.running_mean", "0.4.0.downsample.1.running_var", "0.4.1.conv3.weight", "0.4.1.bn3.weight", "0.4.1.bn3.bias", "0.4.1.bn3.running_mean", "0.4.1.bn3.running_var", "0.4.2.conv3.weight", "0.4.2.bn3.weight", "0.4.2.bn3.bias", "0.4.2.bn3.running_mean", "0.4.2.bn3.running_var", "0.5.0.conv3.weight", "0.5.0.bn3.weight", "0.5.0.bn3.bias", "0.5.0.bn3.running_mean", "0.5.0.bn3.running_var", "0.5.1.conv3.weight", "0.5.1.bn3.weight", "0.5.1.bn3.bias", "0.5.1.bn3.running_mean", "0.5.1.bn3.running_var", "0.5.2.conv3.weight", "0.5.2.bn3.weight", "0.5.2.bn3.bias", "0.5.2.bn3.running_mean", "0.5.2.bn3.running_var", "0.5.3.conv3.weight", "0.5.3.bn3.weight", "0.5.3.bn3.bias", "0.5.3.bn3.running_mean", "0.5.3.bn3.running_var", "0.6.0.conv3.weight", "0.6.0.bn3.weight", "0.6.0.bn3.bias", "0.6.0.bn3.running_mean", "0.6.0.bn3.running_var", "0.6.1.conv3.weight", "0.6.1.bn3.weight", "0.6.1.bn3.bias", "0.6.1.bn3.running_mean", "0.6.1.bn3.running_var", "0.6.2.conv3.weight", "0.6.2.bn3.weight", "0.6.2.bn3.bias", "0.6.2.bn3.running_mean", "0.6.2.bn3.running_var", "0.6.3.conv3.weight", "0.6.3.bn3.weight", "0.6.3.bn3.bias", "0.6.3.bn3.running_mean", "0.6.3.bn3.running_var", "0.6.4.conv3.weight", "0.6.4.bn3.weight", "0.6.4.bn3.bias", "0.6.4.bn3.running_mean", "0.6....
        size mismatch for 0.4.0.conv1.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 1, 1]).
        size mismatch for 0.4.1.conv1.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 256, 1, 1]).
        size mismatch for 0.4.2.conv1.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 256, 1, 1]).
        size mismatch for 0.5.0.conv1.weight: copying a param with shape torch.Size([128, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 256, 1, 1]).
        size mismatch for 0.5.0.downsample.0.weight: copying a param with shape torch.Size([128, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 256, 1, 1]).
        size mismatch for 0.5.0.downsample.1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for 0.5.0.downsample.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for 0.5.0.downsample.1.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for 0.5.0.downsample.1.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for 0.5.1.conv1.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 512, 1, 1]).
        size mismatch for 0.5.2.conv1.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 512, 1, 1]).
        size mismatch for 0.5.3.conv1.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 512, 1, 1]).
        size mismatch for 0.6.0.conv1.weight: copying a param with shape torch.Size([256, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 1, 1]).
        size mismatch for 0.6.0.downsample.0.weight: copying a param with shape torch.Size([256, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 512, 1, 1]).
        size mismatch for 0.6.0.downsample.1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
        size mismatch for 0.6.0.downsample.1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
        size mismatch for 0.6.0.downsample.1.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
        size mismatch for 0.6.0.downsample.1.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
        size mismatch for 0.6.1.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]).
        size mismatch for 0.6.2.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]).
        size mismatch for 0.6.3.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]).
        size mismatch for 0.6.4.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]).
        size mismatch for 0.6.5.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]).
        size mismatch for 0.7.0.conv1.weight: copying a param with shape torch.Size([512, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 1024, 1, 1]).
        size mismatch for 0.7.0.downsample.0.weight: copying a param with shape torch.Size([512, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([2048, 1024, 1, 1]).
        size mismatch for 0.7.0.downsample.1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048]).
        size mismatch for 0.7.0.downsample.1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048]).
        size mismatch for 0.7.0.downsample.1.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048]).
        size mismatch for 0.7.0.downsample.1.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048]).
        size mismatch for 0.7.1.conv1.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 2048, 1, 1]).
        size mismatch for 0.7.2.conv1.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 2048, 1, 1]).
        size mismatch for 1.2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([4096]).
        size mismatch for 1.2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([4096]).
        size mismatch for 1.2.running_mean: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([4096]).
        size mismatch for 1.2.running_var: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([4096]).
        size mismatch for 1.4.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([512, 4096]).

There are a few issues with your approach:

  1. You mention that you trained with ResNet34 and then saved its weights. However, at test time you are building a ResNet50 and trying to load the ResNet34 weights into it. That won't work, because the architectures (layers and parameters) are different.
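
A minimal sketch of the fix (assuming `stage-2` was saved from the resnet34 learner, and reusing the `data_test` bunch and metrics defined above):

    learn = cnn_learner(data_test, models.resnet34,   # same architecture the saved weights came from
                        metrics=[accuracy, top_1], callback_fns=ShowGraph)
    learn.load('stage-2')                             # keys and shapes in the state_dict now match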

  2. If you are testing in the same notebook, add the test set at the time you create the train and valid datasets, and then you can predict on the whole test set with a couple of lines of code. Quick example:

    test_imgs = (path/'cars_test/').ls()               # list the test image files
    data.add_test(test_imgs)                           # attach them as the test set
    learn.data = data
    preds = learn.get_preds(ds_type=DatasetType.Test)  # predict on the whole test set
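
Note that `get_preds` actually returns a (predictions, targets) tuple, and the targets for the test set are dummies, so a hedged sketch of scoring the predictions against your own labels (assuming a hypothetical `true_labels` tensor of class indices built from the test filenames) could look like:

    probs, _ = learn.get_preds(ds_type=DatasetType.Test)  # per-image class probabilities
    pred_classes = probs.argmax(dim=1)                     # predicted class index per image
    acc = (pred_classes == true_labels).float().mean()     # fraction of correct predictions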

  3. If you plan to use the model in a different notebook, try `learn.export()`, which saves the model, its weights, the data and all the other info. I think Jeremy does this in his notebooks and in the course.
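
A minimal sketch of that workflow in fastai v1 (the file and image names here are just illustrative):

    # training notebook: export the model, weights, classes and transforms in one file
    learn.export('export.pkl')                               # written under learn.path

    # inference notebook: load it back without rebuilding the databunch
    from fastai.vision import load_learner, open_image
    learn = load_learner('/content/dataset/', 'export.pkl')
    img = open_image('/content/dataset/test/example.jpg')    # hypothetical test image
    pred_class, pred_idx, probs = learn.predict(img)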