Keras - 带权重的多标签分类

Keras - Multilabel classification with weights

我正在尝试 class验证一些每个样本有多个标签的 CXR 图像。据我了解,我必须放置一个带有 sigmoid 激活的致密层,并使用二元交叉熵作为我的损失函数。问题是存在很大的 class 不平衡(正常比异常多得多)。我很好奇这是我的模型 sofar:

from keras_applications.resnet_v2 import ResNet50V2
from keras.layers import GlobalAveragePooling2D, Dense
from keras import Sequential
ResNet = Sequential()
ResNet.add(ResNet50V2(input_shape=shape, include_top=False, weights=None,backend=keras.backend,
    layers=keras.layers,
    models=keras.models,
    utils=keras.utils))
ResNet.add(GlobalAveragePooling2D(name='avg_pool'))

ResNet.add(Dense(len(label_counts), activation='sigmoid', name='Final_output'))

正如我们所见,我正在使用 sigmoid 来获得输出,但我对如何实现权重有点困惑。我想我需要使用使用 BCE(use_logits = true) 的自定义损失函数。像这样:

xent = tf.losses.BinaryCrossEntropy(
    from_logits=True,
    reduction=tf.keras.losses.Reduction.NONE)
loss = tf.reduce_mean(xent(targets, pred) * weights))

所以它将输出视为对数,但我不确定的是最终输出的激活。我是用 sigmoid 的激活来保持它,还是使用线性激活(未激活)?我假设我们保留了 sigmoid,并将其视为 logit,但我不确定,因为 pytorches“torch.nn.BCEWithLogitsLoss”包含一个 sigmoid 层

编辑:找到这个:https://www.reddit.com/r/tensorflow/comments/dflsgv/binary_cross_entropy_with_from_logits_true/

根据:pgaleone

from_logits=True means that the loss function expects a linear tensor (the output layer of your network without any activation function but the identity), so you have to remove the sigmoid, since it will be the loss function itself to apply the softmax to your network output, and then to compute the cross-entropy

您实际上不想在多标签 class化中使用 from_logits

来自文档 [1]:

logits: Per-label activations, typically a linear output. These activation energies are interpreted as unnormalized log probabilities.

所以你说得对,当激活函数设置为 True 时你不想使用它。

但是,文档还说

WARNING: This op expects unscaled logits, since it performs a softmax on logits internally for efficiency. Do not call this op with the output of softmax, as it will produce incorrect results

Softmax 根据定义优化一个 class。这就是 softmax 设计的工作原理。由于您正在进行多标签 class化,因此您应该使用 sigmoid,正如您自己提到的那样。

这意味着如果你想使用 sigmoid,你不能使用 from_logits 因为它会在 sigmoid 之后应用 softmax,这通常不是你想要的。

解决方法是删除这一行:

from_logits=True,

[1] https://www.tensorflow.org/api_docs/python/tf/nn/softmax_cross_entropy_with_logits?version=stable