在 pytorch 中用于图像分割的通道明智的 CrossEntropyLoss

Channel wise CrossEntropyLoss for image segmentation in pytorch

我正在做图像分割任务。总共有 7 个 classes,所以最终输出是一个像 [batch, 7, height, width] 这样的张量,它是一个 softmax 输出。现在直觉上我想使用 CrossEntropy 损失,但是 pytorch 实现不适用于通道明智的单热编码向量

所以我打算自己做一个功能。在一些 Whosebug 的帮助下,我的代码到目前为止看起来像这样

from torch.autograd import Variable
import torch
import torch.nn.functional as F


def cross_entropy2d(input, target, weight=None, size_average=True):
    # input: (n, c, w, z), target: (n, w, z)
    n, c, w, z = input.size()
    # log_p: (n, c, w, z)
    log_p = F.log_softmax(input, dim=1)
    # log_p: (n*w*z, c)
    log_p = log_p.permute(0, 3, 2, 1).contiguous().view(-1, c)  # make class dimension last dimension
    log_p = log_p[
       target.view(n, w, z, 1).repeat(0, 0, 0, c) >= 0]  # this looks wrong -> Should rather be a one-hot vector
    log_p = log_p.view(-1, c)
    # target: (n*w*z,)
    mask = target >= 0
    target = target[mask]
    loss = F.nll_loss(log_p, target.view(-1), weight=weight, size_average=False)
    if size_average:
        loss /= mask.data.sum()
    return loss


images = Variable(torch.randn(5, 3, 4, 4))
labels = Variable(torch.LongTensor(5, 3, 4, 4).random_(3))
cross_entropy2d(images, labels)

我有两个错误。代码本身提到了一个,它需要 one-hot vector。第二个说以下

RuntimeError: invalid argument 2: size '[5 x 4 x 4 x 1]' is invalid for input with 3840 elements at ..\src\TH\THStorage.c:41

例如,我试图让它在 3 class 问题上起作用。所以目标和标签是(为了简化不包括批处理参数!)

目标:

 Channel 1     Channel 2  Channel 3

[[0 1 1 0 ] [0 0 0 1 ] [1 0 0 0 ] [0 0 1 1 ] [0 0 0 0 ] [1 1 0 0 ] [0 0 0 1 ] [0 0 0 0 ] [1 1 1 0 ] [0 0 0 0 ] [0 0 0 1 ] [1 1 1 0 ]

标签:

 Channel 1     Channel 2  Channel 3

[[0 1 1 0 ] [0 0 0 1 ] [1 0 0 0 ] [0 0 1 1 ] [.2 0 0 0] [.8 1 0 0 ] [0 0 0 1 ] [0 0 0 0 ] [1 1 1 0 ] [0 0 0 0 ] [0 0 0 1 ] [1 1 1 0 ]

那么我该如何修改我的代码来计算通道方面的交叉熵损失?

2D(或 KD)交叉熵是 NN 中非常基本的构建块。 pytorch 不太可能没有 "out-of-the-box" 实现。
查看torch.nn.CrossEntropyLoss and the underlying torch.nn.functional.cross_entropy你会发现损失可以处理2D输入(即4D输入预测张量)。
此外,您可以检查实际实现此 here 的代码,并查看它如何根据 input 张量的 dim 处理不同的情况。

所以,不用麻烦,已经为您完成了!

正如 Shai 的回答所述,torch.nn.CrossEntropy() 函数的文档可以在 here and the code can be found here 中找到。内置函数确实已经支持 KD 交叉熵损失。

在 3D 情况下,torch.nn.CrossEntropy() 函数需要两个参数:4D 输入矩阵和 3D 目标矩阵。输入矩阵的形状为:(Minibatch, 类, H, W)。目标矩阵的形状为 (Minibatch, H, W),数字范围为 0 到 (类-1)。如果你从一个单热编码矩阵开始,你将不得不用 np.argmax().

转换它

具有三个 类 且小批量大小为 1 的示例:

import pytorch
import numpy as np

input_torch = torch.randn(1, 3, 2, 5, requires_grad=True)

one_hot = np.array([[[1, 1, 1, 0, 0], [0, 0, 0, 0, 0]],    
                    [[0, 0, 0, 0, 0], [1, 1, 1, 0, 0]],
                    [[0, 0, 0, 1, 1], [0, 0, 0, 1, 1]]])

target = np.array([np.argmax(a, axis = 0) for a in target])
target_torch = torch.tensor(target_argmax)

loss = torch.nn.CrossEntropyLoss()
output = loss(input_torch, target_torch)
output.backward()