如何获得独特的元素及其首次出现的 pytorch 张量索引?

How to get unique elements and their firstly appeared indices of a pytorch tensor?

假设一个 2*X(总是 2 行)pytorch 张量:

A = tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
            [43., 33., 43., 76., 33., 76., 55., 55., 55.]])

torch.unique(A, dim=1) 将 return:

tensor([[ 1.,  2.,  2.,  3.,  3.,  4.],
        [43., 33., 43., 33., 76., 55.]])

但我还需要每个唯一元素在原始输入中首次出现的位置的索引。在这种情况下,索引应该是这样的:

tensor([0, 1, 2, 3, 4, 6])

# Explanation
# A = tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
#             [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
#              (0)  (1)  (2)  (3)  (4)       (6) 

这对我来说很复杂,因为张量的第二行 A 可能没有很好地排序:

A = tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
            [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
                             ^         ^

是否有一种简单有效的方法来获取所需的索引?

P.S。张量的第一行始终按升序排列可能很有用。

获得此类指标的一种可能方式:

unique, idx, counts = torch.unique(A, dim=1, sorted=True, return_inverse=True, return_counts=True)
_, ind_sorted = torch.sort(idx, stable=True)
cum_sum = counts.cumsum(0)
cum_sum = torch.cat((torch.tensor([0]), cum_sum[:-1]))
first_indicies = ind_sorted[cum_sum]

对于上面代码段中的张量 A

print(first_indicies)
# tensor([0, 1, 2, 4, 3, 6])

请注意,在这种情况下 unique 等于:

 tensor([[ 1.,  2.,  2.,  3.,  3.,  4.],
         [43., 33., 43., 33., 76., 55.]])