使用另一个张量选择 pytorch 张量的条目
Selecting entries of a pytorch tensor with another tensor
我有一个带有浮点条目的张量 a
和 torch.Size([64,2])
,我还有一个带有 torch.Size([64])
的张量 b
。 b
的条目只有0
或1
.
我想得到一个新的张量 c
和 torch.Size([64])
使得每个索引 i 的 c[i] == a[i,b[i]]
。我该怎么做?
我的尝试
我尝试使用 torch.gather
但没有成功。下面的代码给我 RuntimeError: Index tensor must have the same number of dimensions as input tensor
import torch
a = torch.zeros([64,2])
b = torch.ones(64).long()
torch.gather(input=a, dim=1,index=b)
非常感谢任何帮助!
不确定我是否理解你的问题,但我认为你可以遍历你的张量
a = torch.zeros([64,2])
b = torch.ones(64).long()
c = torch.empty([64])
for i, _ in enumerate(a):
c[i] = a[i,b[i]]
c
您可以在两个维度上使用索引 a
直接执行此操作:
在 dimension=0
上:使用 torch.arange
的“顺序”索引。
在 dimension=1
上:使用 b
进行索引。
总而言之,这给出了:
>>> a[torch.arange(len(a)), b]
或者您可以使用torch.gather
,您要查找的操作是:
# c[i] == a[i,b[i]]
应用于 dim=1
时提供的收集操作提供如下内容:
# c[i,j] == a[i,b[i,j]]
如您所见,我们需要考虑 a
和 b
之间的形状差异。为此,您可以取消压缩 b
上的单例维度(用上面的字母 j
注释),例如 #b=(64, 1)
,例如 b.unsqueeze(-1)
或 b[...,None]
:
>>> a.gather(dim=1, index=b[...,None]).flatten()
我有一个带有浮点条目的张量 a
和 torch.Size([64,2])
,我还有一个带有 torch.Size([64])
的张量 b
。 b
的条目只有0
或1
.
我想得到一个新的张量 c
和 torch.Size([64])
使得每个索引 i 的 c[i] == a[i,b[i]]
。我该怎么做?
我的尝试
我尝试使用 torch.gather
但没有成功。下面的代码给我 RuntimeError: Index tensor must have the same number of dimensions as input tensor
import torch
a = torch.zeros([64,2])
b = torch.ones(64).long()
torch.gather(input=a, dim=1,index=b)
非常感谢任何帮助!
不确定我是否理解你的问题,但我认为你可以遍历你的张量
a = torch.zeros([64,2])
b = torch.ones(64).long()
c = torch.empty([64])
for i, _ in enumerate(a):
c[i] = a[i,b[i]]
c
您可以在两个维度上使用索引 a
直接执行此操作:
在
dimension=0
上:使用torch.arange
的“顺序”索引。在
dimension=1
上:使用b
进行索引。
总而言之,这给出了:
>>> a[torch.arange(len(a)), b]
或者您可以使用torch.gather
,您要查找的操作是:
# c[i] == a[i,b[i]]
应用于 dim=1
时提供的收集操作提供如下内容:
# c[i,j] == a[i,b[i,j]]
如您所见,我们需要考虑 a
和 b
之间的形状差异。为此,您可以取消压缩 b
上的单例维度(用上面的字母 j
注释),例如 #b=(64, 1)
,例如 b.unsqueeze(-1)
或 b[...,None]
:
>>> a.gather(dim=1, index=b[...,None]).flatten()