如何在火炬张量中交换两行?
How two rows can be swapped in a torch tensor?
var = [[0, 1, -4, 8],
[2, -3, 2, 1],
[5, -8, 7, 1]]
var = torch.Tensor(var)
这里,var
是一个 3 x 4 (2d) 张量。如何交换第一行和第二行以获得以下二维张量?
2, -3, 2, 1
0, 1, -4, 8
5, -8, 7, 1
生成你想要的排列索引:
index = torch.LongTensor([1,0,2])
应用排列:
var[index] = var
不起作用,因为某些维度在复制之前会被覆盖:
>>> var = [[0, 1, -4, 8],
[2, -3, 2, 1],
[5, -8, 7, 1]]
>>> x = torch.tensor(var)
>>> index = torch.LongTensor([1, 0, 2])
>>> x[index] = x
>>> x
tensor([[ 0, 1, -4, 8],
[ 0, 1, -4, 8],
[ 5, -8, 7, 1]])
对我来说,创建一个新的张量(具有单独的底层存储)来保存结果就足够了:
>>> x = torch.tensor(var)
>>> index = torch.LongTensor([1, 0, 2])
>>> y = torch.zeros_like(x)
>>> y[index] = x
或者,您可以使用 index_copy_
(following this explanation in discuss.pytorch.org),尽管目前我看不出这两种方式有什么优势。
正如其他答案所建议的那样,您的排列索引本身应该是张量,但这不是必需的。您可以像这样交换第一行和第二行:
>>> var
tensor([[ 0, 1, -4, 8],
[ 2, -3, 2, 1],
[ 5, -8, 7, 1]])
>>> var[[0, 1]] = var[[1, 0]]
>>> var
tensor([[ 2, -3, 2, 1],
[ 0, 1, -4, 8],
[ 5, -8, 7, 1]])
var
可以是 NumPy 数组或 PyTorch 张量。
您可以为此使用 index_select
:
>>> idx = torch.LongTensor([1,0,2])
>>> var.index_select(0, idx)
tensor([[ 2, -3, 2, 1],
[ 0, 1, -4, 8],
[ 5, -8, 7, 1]])
var = [[0, 1, -4, 8],
[2, -3, 2, 1],
[5, -8, 7, 1]]
var = torch.Tensor(var)
这里,var
是一个 3 x 4 (2d) 张量。如何交换第一行和第二行以获得以下二维张量?
2, -3, 2, 1
0, 1, -4, 8
5, -8, 7, 1
生成你想要的排列索引:
index = torch.LongTensor([1,0,2])
应用排列:
var[index] = var
>>> var = [[0, 1, -4, 8],
[2, -3, 2, 1],
[5, -8, 7, 1]]
>>> x = torch.tensor(var)
>>> index = torch.LongTensor([1, 0, 2])
>>> x[index] = x
>>> x
tensor([[ 0, 1, -4, 8],
[ 0, 1, -4, 8],
[ 5, -8, 7, 1]])
对我来说,创建一个新的张量(具有单独的底层存储)来保存结果就足够了:
>>> x = torch.tensor(var)
>>> index = torch.LongTensor([1, 0, 2])
>>> y = torch.zeros_like(x)
>>> y[index] = x
或者,您可以使用 index_copy_
(following this explanation in discuss.pytorch.org),尽管目前我看不出这两种方式有什么优势。
正如其他答案所建议的那样,您的排列索引本身应该是张量,但这不是必需的。您可以像这样交换第一行和第二行:
>>> var
tensor([[ 0, 1, -4, 8],
[ 2, -3, 2, 1],
[ 5, -8, 7, 1]])
>>> var[[0, 1]] = var[[1, 0]]
>>> var
tensor([[ 2, -3, 2, 1],
[ 0, 1, -4, 8],
[ 5, -8, 7, 1]])
var
可以是 NumPy 数组或 PyTorch 张量。
您可以为此使用 index_select
:
>>> idx = torch.LongTensor([1,0,2])
>>> var.index_select(0, idx)
tensor([[ 2, -3, 2, 1],
[ 0, 1, -4, 8],
[ 5, -8, 7, 1]])