使用另一个张量选择 pytorch 张量的条目

Selecting entries of a pytorch tensor with another tensor

我有一个带有浮点条目的张量 atorch.Size([64,2]),我还有一个带有 torch.Size([64]) 的张量 bb的条目只有01.

我想得到一个新的张量 ctorch.Size([64]) 使得每个索引 i 的 c[i] == a[i,b[i]]。我该怎么做?

我的尝试 我尝试使用 torch.gather 但没有成功。下面的代码给我 RuntimeError: Index tensor must have the same number of dimensions as input tensor

import torch
a = torch.zeros([64,2])
b = torch.ones(64).long()
torch.gather(input=a, dim=1,index=b)

非常感谢任何帮助!

不确定我是否理解你的问题,但我认为你可以遍历你的张量

a = torch.zeros([64,2])
b = torch.ones(64).long()
c = torch.empty([64])
for i, _ in enumerate(a):
  c[i] = a[i,b[i]]
c

您可以在两个维度上使用索引 a 直接执行此操作:

  • dimension=0 上:使用 torch.arange 的“顺序”索引。

  • dimension=1 上:使用 b 进行索引。

总而言之,这给出了:

>>> a[torch.arange(len(a)), b]

或者您可以使用torch.gather,您要查找的操作是:

# c[i] == a[i,b[i]]

应用于 dim=1 时提供的收集操作提供如下内容:

# c[i,j] == a[i,b[i,j]]

如您所见,我们需要考虑 ab 之间的形状差异。为此,您可以取消压缩 b 上的单例维度(用上面的字母 j 注释),例如 #b=(64, 1),例如 b.unsqueeze(-1)b[...,None] :

>>> a.gather(dim=1, index=b[...,None]).flatten()