Pytorch 展开 class 标签张量

Pytorch expand class label tensor

我正在 pytorch 中进行语义分割项目,我有以下形状的 class 映射:[H,W] 其中每个元素都是 0-n 之间的整数,其中 n 是class是的,H是图片的高度,W是图片的宽度。

这是一个例子:

test_label = torch.zeros([10,10])
test_label[:5,:5] = 1
test_label[5:,:5] = 2
test_label[:5,5:] = 3
test_label

输出:

tensor([[1., 1., 1., 1., 1., 3., 3., 3., 3., 3.],
    [1., 1., 1., 1., 1., 3., 3., 3., 3., 3.],
    [1., 1., 1., 1., 1., 3., 3., 3., 3., 3.],
    [1., 1., 1., 1., 1., 3., 3., 3., 3., 3.],
    [1., 1., 1., 1., 1., 3., 3., 3., 3., 3.],
    [2., 2., 2., 2., 2., 0., 0., 0., 0., 0.],
    [2., 2., 2., 2., 2., 0., 0., 0., 0., 0.],
    [2., 2., 2., 2., 2., 0., 0., 0., 0., 0.],
    [2., 2., 2., 2., 2., 0., 0., 0., 0., 0.],
    [2., 2., 2., 2., 2., 0., 0., 0., 0., 0.]])

现在,我想要的是形状为 [n,C,H] 的东西,其中 [1,C,H] 将是例如:

tensor([[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
    [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
    [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
    [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
    [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

而 [2,H,W] 将是:

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
    [2., 2., 2., 2., 2., 0., 0., 0., 0., 0.],
    [2., 2., 2., 2., 2., 0., 0., 0., 0., 0.],
    [2., 2., 2., 2., 2., 0., 0., 0., 0., 0.],
    [2., 2., 2., 2., 2., 0., 0., 0., 0., 0.],
    [2., 2., 2., 2., 2., 0., 0., 0., 0., 0.]])

是否有一个 pytorch 函数可以做到这一点?我当前的方法是迭代屏蔽原始张量中的每个唯一元素,并将它们插入到形状为 [n,H,W] 的张量中,该张量最初填充为全零。但这似乎不是最好的方法。我试图查找它,但似乎无法为该操作找到正确的名称。

非常感谢您的宝贵时间。

您可以应用 nn.functional.one_hot 将密集格式转换为 one-hot 编码,然后与标签值相乘以获得所需的结果:

>>> C = int(x.max()) + 1
>>> ohe = F.one_hot(x.long(), num_classes=C)

然后乘以标签值:

>>> res = ohe*torch.arange(C)
>>> res.permute(2,0,1)
tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 2, 2, 2, 0, 0, 0, 0, 0],
         [2, 2, 2, 2, 2, 0, 0, 0, 0, 0],
         [2, 2, 2, 2, 2, 0, 0, 0, 0, 0],
         [2, 2, 2, 2, 2, 0, 0, 0, 0, 0],
         [2, 2, 2, 2, 2, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 3, 3, 3, 3, 3],
         [0, 0, 0, 0, 0, 3, 3, 3, 3, 3],
         [0, 0, 0, 0, 0, 3, 3, 3, 3, 3],
         [0, 0, 0, 0, 0, 3, 3, 3, 3, 3],
         [0, 0, 0, 0, 0, 3, 3, 3, 3, 3],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]])