如何在不使用 for 循环的情况下反转 2D Pytorch 张量的每一行中的排列?

How to invert the permutations in each row of a 2D Pytorch tensor without using for loop?

假设我有一个二维张量,每行包含一个排列,例如

a = torch.tensor([[4, 3, 5, 0, 2, 1],
                  [5, 0, 3, 4, 2, 1],
                  [3, 1, 0, 2, 4, 5],
                  [5, 0, 4, 3, 2, 1],
                  [2, 4, 0, 1, 5, 3]])

我想反转张量 a 中的所有排列以获得张量 b。例如,在对上述张量 a 执行此操作后,我想要的输出应该是

>>> b
tensor([[3, 5, 4, 1, 0, 2],
        [1, 5, 4, 2, 3, 0],
        [2, 1, 3, 0, 4, 5],
        [1, 5, 4, 3, 2, 0],
        [2, 3, 0, 5, 1, 4]])

我尝试在线搜索并找到了 答案。我应该如何概括该答案中提到的方法以在不使用 for 循环的情况下反转所有排列?

可以使用torch.Tensor.scatter_

来完成
b = torch.zeros_like(a)
b.scatter_(1, a, torch.arange(a.size(1)).expand(a.size(0),-1))