用另一个多维张量索引多维火炬张量
Index multidimensional torch tensor by another multidimensional tensor
我在 pytorch 中有一个张量 x 让我们说形状 (5,3,2,6) 和另一个张量 idx形状 (5,3,2,1),其中包含第一个张量中每个元素的索引。我想要用第二个张量的索引对第一个张量进行切片。我试过 x= x[idx] 但当我真的希望它的形状为 (5,3,2) 或 (5,3,2,1).
时,我得到了一个奇怪的维度
我将尝试举一个更简单的例子:
比方说
x=torch.Tensor([[10,20,30],
[8,4,43]])
idx = torch.Tensor([[0],
[2]])
我想要类似的东西
y = x[idx]
这样 'y' 输出 [[10],[43]]
或类似的东西。
索引表示最后一维所需元素的位置。对于上面的示例,其中 x.shape = (2,3) 最后一个维度是列,然后 'idx' 中的索引是列。我想要这个但是超过 2 个维度
据我从评论中了解到,您需要 idx
作为最后一个维度的索引,并且 idx
中的每个索引对应于 x
中的类似索引(除了最后一个维度)。那样的话(这是numpy的版本,你可以把它转换成torch):
ind = np.indices(idx.shape)
ind[-1] = idx
x[tuple(ind)]
输出:
[[10]
[43]]
可以使用range
;和 squeeze
以获得适当的 idx
维度,如
x[range(x.size(0)), idx.squeeze()]
tensor([10., 43.])
# or
x[range(x.size(0)), idx.squeeze()].unsqueeze(1)
tensor([[10.],
[43.]])
这是使用 gather
在 PyTorch 中工作的那个。 idx
需要采用以下行将确保的 torch.int64
格式(注意 tensor
中 't' 的小写)。
idx = torch.tensor([[0],
[2]])
torch.gather(x, 1, idx) # 1 is the axis to index here
tensor([[10.],
[43.]])
我在 pytorch 中有一个张量 x 让我们说形状 (5,3,2,6) 和另一个张量 idx形状 (5,3,2,1),其中包含第一个张量中每个元素的索引。我想要用第二个张量的索引对第一个张量进行切片。我试过 x= x[idx] 但当我真的希望它的形状为 (5,3,2) 或 (5,3,2,1).
时,我得到了一个奇怪的维度我将尝试举一个更简单的例子: 比方说
x=torch.Tensor([[10,20,30],
[8,4,43]])
idx = torch.Tensor([[0],
[2]])
我想要类似的东西
y = x[idx]
这样 'y' 输出 [[10],[43]]
或类似的东西。
索引表示最后一维所需元素的位置。对于上面的示例,其中 x.shape = (2,3) 最后一个维度是列,然后 'idx' 中的索引是列。我想要这个但是超过 2 个维度
据我从评论中了解到,您需要 idx
作为最后一个维度的索引,并且 idx
中的每个索引对应于 x
中的类似索引(除了最后一个维度)。那样的话(这是numpy的版本,你可以把它转换成torch):
ind = np.indices(idx.shape)
ind[-1] = idx
x[tuple(ind)]
输出:
[[10]
[43]]
可以使用range
;和 squeeze
以获得适当的 idx
维度,如
x[range(x.size(0)), idx.squeeze()]
tensor([10., 43.])
# or
x[range(x.size(0)), idx.squeeze()].unsqueeze(1)
tensor([[10.],
[43.]])
这是使用 gather
在 PyTorch 中工作的那个。 idx
需要采用以下行将确保的 torch.int64
格式(注意 tensor
中 't' 的小写)。
idx = torch.tensor([[0],
[2]])
torch.gather(x, 1, idx) # 1 is the axis to index here
tensor([[10.],
[43.]])