如何根据索引张量对单热张量进行排序

How to sort a one hot tensor according to a tensor of indices

给定以下张量:

tensor = torch.Tensor([[1., 0., 0., 0., 0.],
                       [0., 1., 0., 0., 0.],
                       [0., 0., 1., 0., 0.],
                       [0., 0., 0., 0., 1.],
                       [1., 0., 0., 0., 0.],
                       [1., 0., 0., 0., 0.],
                       [0., 0., 0., 1., 0.],
                       [0., 0., 0., 0., 1.]])

以下是包含索引的张量:

indices = torch.tensor([2, 6, 7, 5, 4, 0, 3, 1])  

如何使用 indices 中的值对 tensor 进行排序?

尝试使用 sorted 给出错误:

TypeError: 'Tensor' object is not callable`.

numpy.sort 给出:

ValueError: Cannot specify order when the array has no fields.`

您可以像这样使用索引:

tensor = torch.Tensor([[1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0.],
[1., 0., 0., 0., 0.],
[0., 0., 0., 1., 0.],
[0., 0., 0., 0., 1.]])
indices = torch.tensor([2, 6, 7, 5, 4, 0, 3, 1]) 
sorted_tensor = tensor[indices]
print(sorted_tensor)
# output
tensor([[0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.]])