为直方图箱的值创建一次性编码

create one-hot encoding for values of histogram bins

给定以下大小的张量 torch.Size([22])

tensor([-20.1659, -19.7022, -17.4124, -16.7115, -16.4696, -15.6848, -15.5201, -14.5384, -12.5017, -12.4227, -11.0946, -10.7844, -10.5467,  -9.3933,  -4.2351,  -4.0521,  -3.8844, -3.8668,  -3.7337,  -3.7002,  -3.6242,  -3.5820])  

和下面的直方图:

hist = torch.histogram(tensor, 5)
hist
torch.return_types.histogram(
hist=tensor([3., 5., 5., 1., 8.]),
bin_edges=tensor([-20.1659, -16.8491, -13.5323, -10.2156,  -6.8988,  -3.5820]))

对于张量的每个值,如何创建一个对应其bin编号的one hot编码,以便输出大小为torch.Size([22, 5])

的张量

您可以使用torch.repeat_interleave

import torch

bins = torch.tensor([3, 5, 5, 1, 8])
one_hots = torch.eye(len(bins))
one_hots = torch.repeat_interleave(one_hots, bins, dim=0)
print(one_hots)
输出
tensor([[1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.]])