Pytorch 中的多标签 class 化 class 不平衡
Multilabel classification with class imbalance in Pytorch
我有一个多标签 class化问题,我正试图用 Pytorch 中的 CNN 解决这个问题。我有 80,000 个训练示例和 7900 classes;每个示例可以同时属于多个 classes,每个示例的 classes 的平均数量是 130。
问题是我的数据集很不平衡。对于某些 classes,我只有 ~900 个示例,大约是 1%。对于“代表过多”的 classes,我有大约 12000 个示例 (15%)。当我训练模型时,我使用来自 pytorch 的 BCEWithLogitsLoss 和一个正权重参数。我按照文档中描述的相同方式计算权重:负示例的数量除以正示例的数量。
因此,我的模型几乎高估了每个 class… 或次要和主要 classes 我得到的预测几乎是真实标签的两倍。而我的 AUPRC 仅为 0.18。尽管它比根本没有加权要好得多,因为在这种情况下,模型预测一切为零。
所以我的问题是,如何提高性能?还有什么我可以做的吗?我尝试了不同的批量采样技术(对少数人进行过采样 class),但它们似乎不起作用。
我会建议其中一种策略
焦点丢失
在
中介绍了一种通过调整损失函数来处理不平衡训练数据的非常有趣的方法
Tsung-Yi Lin、Priya Goyal、Ross Girshick、Kaiming He 和 Piotr DollarFocal Loss for Dense Object Detection(ICCV 2017)。
他们建议以一种方式修改二元交叉熵损失,以减少易于分类的示例的损失和梯度,同时“将精力集中在”模型出现严重错误的示例上。
硬负挖掘
另一种流行的方法是进行“hard negative mining”;也就是说,仅为部分训练示例传播梯度 - “硬”示例。
参见,例如:
Abhinav Shrivastava、Abhinav Gupta 和 Ross GirshickTraining Region-based Object Detectors with Online Hard Example Mining(CVPR 2016)
@Shai 提供了两种在深度学习时代发展起来的策略。我想为您提供一些额外的传统机器学习选项:过采样和欠采样。
它们的主要思想是在开始训练之前通过采样生成更平衡的数据集。请注意,您可能会遇到一些问题,例如丢失数据多样性(欠采样)和过度拟合训练数据(过采样),但这可能是一个很好的起点。
有关详细信息,请参阅 wiki link。
我有一个多标签 class化问题,我正试图用 Pytorch 中的 CNN 解决这个问题。我有 80,000 个训练示例和 7900 classes;每个示例可以同时属于多个 classes,每个示例的 classes 的平均数量是 130。
问题是我的数据集很不平衡。对于某些 classes,我只有 ~900 个示例,大约是 1%。对于“代表过多”的 classes,我有大约 12000 个示例 (15%)。当我训练模型时,我使用来自 pytorch 的 BCEWithLogitsLoss 和一个正权重参数。我按照文档中描述的相同方式计算权重:负示例的数量除以正示例的数量。
因此,我的模型几乎高估了每个 class… 或次要和主要 classes 我得到的预测几乎是真实标签的两倍。而我的 AUPRC 仅为 0.18。尽管它比根本没有加权要好得多,因为在这种情况下,模型预测一切为零。
所以我的问题是,如何提高性能?还有什么我可以做的吗?我尝试了不同的批量采样技术(对少数人进行过采样 class),但它们似乎不起作用。
我会建议其中一种策略
焦点丢失
在
中介绍了一种通过调整损失函数来处理不平衡训练数据的非常有趣的方法
Tsung-Yi Lin、Priya Goyal、Ross Girshick、Kaiming He 和 Piotr DollarFocal Loss for Dense Object Detection(ICCV 2017)。
他们建议以一种方式修改二元交叉熵损失,以减少易于分类的示例的损失和梯度,同时“将精力集中在”模型出现严重错误的示例上。
硬负挖掘
另一种流行的方法是进行“hard negative mining”;也就是说,仅为部分训练示例传播梯度 - “硬”示例。
参见,例如:
Abhinav Shrivastava、Abhinav Gupta 和 Ross GirshickTraining Region-based Object Detectors with Online Hard Example Mining(CVPR 2016)
@Shai 提供了两种在深度学习时代发展起来的策略。我想为您提供一些额外的传统机器学习选项:过采样和欠采样。
它们的主要思想是在开始训练之前通过采样生成更平衡的数据集。请注意,您可能会遇到一些问题,例如丢失数据多样性(欠采样)和过度拟合训练数据(过采样),但这可能是一个很好的起点。
有关详细信息,请参阅 wiki link。