如何将 PyTorch 张量的每一行中的重复值清零?

How can I zero out duplicate values in each row of a PyTorch tensor?

我想编写一个函数来实现 中描述的行为。

也就是说,我想将 PyTorch 中矩阵每一行中的重复值清零。例如,给定一个矩阵

torch.Tensor(([1, 2, 3, 4, 3, 3, 4],
              [1, 6, 3, 5, 3, 5, 4]])

我想得到

torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
              [1, 6, 3, 5, 0, 0, 4]])

torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
              [1, 6, 3, 5, 4, 0, 0]])

根据链接的问题,仅 torch.unique() 是不够的。我想知道如何在没有循环的情况下实现这个功能。

x = torch.tensor([
    [1, 2, 3, 4, 3, 3, 4],
    [1, 6, 3, 5, 3, 5, 4]
], dtype=torch.long)

# sorting the rows so that duplicate values appear together
# e.g., first row: [1, 2, 3, 3, 3, 4, 4]
y, indices = x.sort(dim=-1)

# subtracting, so duplicate values will become 0
# e.g., first row: [1, 2, 3, 0, 0, 4, 0]
y[:, 1:] *= ((y[:, 1:] - y[:, :-1]) !=0).long()

# retrieving the original indices of elements
indices = indices.sort(dim=-1)[1]

# re-organizing the rows following original order
# e.g., first row: [1, 2, 3, 4, 0, 0, 0]
result = torch.gather(y, 1, indices)

print(result) # => output

输出

tensor([[1, 2, 3, 4, 0, 0, 0],
        [1, 6, 3, 5, 0, 0, 4]])