如何获得独特的元素及其首次出现的 pytorch 张量索引?
How to get unique elements and their firstly appeared indices of a pytorch tensor?
假设一个 2*X(总是 2 行)pytorch 张量:
A = tensor([[ 1., 2., 2., 3., 3., 3., 4., 4., 4.],
[43., 33., 43., 76., 33., 76., 55., 55., 55.]])
torch.unique(A, dim=1)
将 return:
tensor([[ 1., 2., 2., 3., 3., 4.],
[43., 33., 43., 33., 76., 55.]])
但我还需要每个唯一元素在原始输入中首次出现的位置的索引。在这种情况下,索引应该是这样的:
tensor([0, 1, 2, 3, 4, 6])
# Explanation
# A = tensor([[ 1., 2., 2., 3., 3., 3., 4., 4., 4.],
# [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
# (0) (1) (2) (3) (4) (6)
这对我来说很复杂,因为张量的第二行 A
可能没有很好地排序:
A = tensor([[ 1., 2., 2., 3., 3., 3., 4., 4., 4.],
[43., 33., 43., 76., 33., 76., 55., 55., 55.]])
^ ^
是否有一种简单有效的方法来获取所需的索引?
P.S。张量的第一行始终按升序排列可能很有用。
获得此类指标的一种可能方式:
unique, idx, counts = torch.unique(A, dim=1, sorted=True, return_inverse=True, return_counts=True)
_, ind_sorted = torch.sort(idx, stable=True)
cum_sum = counts.cumsum(0)
cum_sum = torch.cat((torch.tensor([0]), cum_sum[:-1]))
first_indicies = ind_sorted[cum_sum]
对于上面代码段中的张量 A
:
print(first_indicies)
# tensor([0, 1, 2, 4, 3, 6])
请注意,在这种情况下 unique
等于:
tensor([[ 1., 2., 2., 3., 3., 4.],
[43., 33., 43., 33., 76., 55.]])
假设一个 2*X(总是 2 行)pytorch 张量:
A = tensor([[ 1., 2., 2., 3., 3., 3., 4., 4., 4.],
[43., 33., 43., 76., 33., 76., 55., 55., 55.]])
torch.unique(A, dim=1)
将 return:
tensor([[ 1., 2., 2., 3., 3., 4.],
[43., 33., 43., 33., 76., 55.]])
但我还需要每个唯一元素在原始输入中首次出现的位置的索引。在这种情况下,索引应该是这样的:
tensor([0, 1, 2, 3, 4, 6])
# Explanation
# A = tensor([[ 1., 2., 2., 3., 3., 3., 4., 4., 4.],
# [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
# (0) (1) (2) (3) (4) (6)
这对我来说很复杂,因为张量的第二行 A
可能没有很好地排序:
A = tensor([[ 1., 2., 2., 3., 3., 3., 4., 4., 4.],
[43., 33., 43., 76., 33., 76., 55., 55., 55.]])
^ ^
是否有一种简单有效的方法来获取所需的索引?
P.S。张量的第一行始终按升序排列可能很有用。
获得此类指标的一种可能方式:
unique, idx, counts = torch.unique(A, dim=1, sorted=True, return_inverse=True, return_counts=True)
_, ind_sorted = torch.sort(idx, stable=True)
cum_sum = counts.cumsum(0)
cum_sum = torch.cat((torch.tensor([0]), cum_sum[:-1]))
first_indicies = ind_sorted[cum_sum]
对于上面代码段中的张量 A
:
print(first_indicies)
# tensor([0, 1, 2, 4, 3, 6])
请注意,在这种情况下 unique
等于:
tensor([[ 1., 2., 2., 3., 3., 4.],
[43., 33., 43., 33., 76., 55.]])