实现多类骰子损失函数

Implementing Multiclass Dice Loss Function

我正在使用 UNet 进行多重 class 分割。我对模型的输入是 HxWxC,输出是

outputs = layers.Conv2D(n_classes, (1, 1), activation='sigmoid')(decoder0)

使用 SparseCategoricalCrossentropy 我可以很好地训练网络。现在我也想尝试骰子系数作为损失函数。实现如下,

def dice_loss(y_true, y_pred, smooth=1e-6):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.math.sigmoid(y_pred)

    numerator = 2 * tf.reduce_sum(y_true * y_pred) + smooth
    denominator = tf.reduce_sum(y_true + y_pred) + smooth

    return 1 - numerator / denominator

然而,我实际上得到的是增加的损失而不是减少的损失。我已经检查了多个来源,但我发现所有 material 都使用二进制 class 化的骰子损失,而不是 multiclass。所以我的问题是执行有问题。

问题是你的骰子损失没有解决你拥有的 类 的数量,而是假设二进制情况,所以它可能解释你损失的增加。

您应该实施广义骰子损失,以解释所有 类 和 return 的价值。

类似于以下内容:

def dice_coef_9cat(y_true, y_pred, smooth=1e-7):
    '''
    Dice coefficient for 10 categories. Ignores background pixel label 0
    Pass to model as metric during compile statement
    '''
    y_true_f = K.flatten(K.one_hot(K.cast(y_true, 'int32'), num_classes=10)[...,1:])
    y_pred_f = K.flatten(y_pred[...,1:])
    intersect = K.sum(y_true_f * y_pred_f, axis=-1)
    denom = K.sum(y_true_f + y_pred_f, axis=-1)
    return K.mean((2. * intersect / (denom + smooth)))

def dice_coef_9cat_loss(y_true, y_pred):
    '''
    Dice loss to minimize. Pass to model as loss during compile statement
    '''
    return 1 - dice_coef_9cat(y_true, y_pred)

此片段摘自https://github.com/keras-team/keras/issues/9395#issuecomment-370971561

这是 9 个类别,您应该根据自己的类别数进行调整。

如果要进行多 class 分割,则应使用 'softmax' 激活函数。

我建议使用单热编码的真实掩码。这需要在损失计算代码之外完成。

广义骰子损失和其他实现如下link:

https://github.com/NifTK/NiftyNet/blob/dev/niftynet/layer/loss_segmentation.py

不知道为什么,但最后一层有“sigmoid”作为激活函数。 对于 Multiclass 分割,它必须是“softmax”而不是“sigmoid”。

此外,您正在考虑的损失是 SparseCategoricalCrossentropy 以及多通道输出。如果最后一层只有一个通道(在进行多 class 分割时),那么使用 SparseCategoricalCrossentropy 是有意义的,但是当你有多个通道作为输出时,要考虑的损失是“CategoricalCrossentropy”。

由于激活和输出通道不匹配(如上所述),您的损失正在增加。

改变

outputs = layers.Conv2D(n_classes, (1, 1), activation='sigmoid')(decoder0)

outputs = layers.Conv2D(n_classes, (1, 1), activation='softmax')(decoder0)