如何根据输出张量从pytorch模型中删除预测头?

How to remove a prediction head from pytorch model based on the output tensor?

我正在从事与 ViT(Vision Transformer)相关的项目,一些低级定义在 timm 库的深处,我无法更改。低级库定义涉及线性分类预测头,它不是我网络的一部分。

在我切换到 DDP 并行实现之前,一切都很好。 Pytorch 抱怨一些参数对损失没有贡献,它指示我使用“find_unused_parameters=True”。事实上,这是一个常见的场景,如果我将这个“find_unused_parameters=True”添加到训练例程中,它会再次起作用。但是,我只允许更改我们代码库中的模型定义,但我不能修改与训练相关的任何内容……

所以我想我现在唯一能做的就是从模型中“移除”线性头。 虽然我无法深入研究 ViT 的低级定义,但我可以像这样输出这个张量:

encoder_output,   linear_head_output =  ViT(input)

是否可以根据这个linear_head_output张量去掉这个线性预测头?

只需在调用 timm.create_model() 创建 ViT 模型时设置 num_classes=0

这是来自 TIMM documentation on Feature Extraction 的示例:

import torch
import timm
m = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool='')
o = m(torch.randn(2, 3, 224, 224))
print(f'Unpooled shape: {o.shape}')