如何在 Pytorch 中定义“不关心”class?

How to define a “don't care” class in Pytorch?

我有一个时间序列 classification 任务,其中我应该为每个时间戳 t.[=12 输出 3 classes 的 classification =]

所有数据都按帧标记。

数据集中有3个以上class[也是不平衡的]。

我的网络应该按顺序查看所有样本,因为它使用它来获取历史信息。
因此,我不能只在预处理时消除所有不相关的 class 个样本。

如果预测帧的标记与这 3 个 class 不同,我不关心结果。


如何在 Pytorch 中正确执行此操作?

来自 this discussion, which was not google searchable, there are two options, both are options of the CrossEntropyLoss 的关注:

选项 1

如果只有一个class可以忽略,在实例化损失时使用ignore_index=class_index

选项 2

如果有更多class,则使用weight=weights,与weights.shape==n_classestorch.sum(weights[ignored_classes]) == 0