对于给定的条件,获取 2D 张量 A 中值的索引,使用它们来索引 3D 张量 B
For a given condition, get indices of values in 2D tensor A, use those to index a 3D tensor B
对于给定的二维张量,我想检索值为 1
的所有索引。我希望能够简单地使用 torch.nonzero(a == 1).squeeze()
,这将 return tensor([1, 3, 2])
。但是,torch.nonzero(a == 1)
return 是一个二维张量(没关系),每行有两个值(这不是我所期望的)。然后应该使用 returned 索引来索引 3D 张量的第二维(索引 1),再次 returning 2D 张量。
import torch
a = torch.Tensor([[12, 1, 0, 0],
[4, 9, 21, 1],
[10, 2, 1, 0]])
b = torch.rand(3, 4, 8)
print('a_size', a.size())
# a_size torch.Size([3, 4])
print('b_size', b.size())
# b_size torch.Size([3, 4, 8])
idxs = torch.nonzero(a == 1)
print('idxs_size', idxs.size())
# idxs_size torch.Size([3, 2])
print(b.gather(1, idxs))
显然,这不起作用,导致运行时错误:
RuntimeError: invalid argument 4: Index tensor must have same
dimensions as input tensor at
C:\w\s\windows\pytorch\aten\src\TH/generic/THTensorEvenMoreMath.cpp:453
看来idxs
并不是我想的那样,也不能按照我想的那样使用。 idxs
是
tensor([[0, 1],
[1, 3],
[2, 2]])
但是通读 documentation 我不明白为什么我还要取回结果张量中的行索引。现在,我知道我可以通过切片 idxs[:, 1]
获得正确的 idxs,但是我仍然不能将这些值用作 3D 张量的索引,因为会出现与之前相同的错误。是否可以将索引的一维张量用于 select 个给定维度的项目?
import torch
a = torch.Tensor([[12, 1, 0, 0],
[4, 9, 21, 1],
[10, 2, 1, 0]])
b = torch.rand(3, 4, 8)
print('a_size', a.size())
# a_size torch.Size([3, 4])
print('b_size', b.size())
# b_size torch.Size([3, 4, 8])
#idxs = torch.nonzero(a == 1, as_tuple=True)
idxs = torch.nonzero(a == 1)
#print('idxs_size', idxs.size())
print(torch.index_select(b,1,idxs[:,1]))
假设b
的三个维度为batch_size x sequence_length x features
(b x s x feats),可以实现如下预期结果
import torch
a = torch.Tensor([[12, 1, 0, 0],
[4, 9, 21, 1],
[10, 2, 1, 0]])
b = torch.rand(3, 4, 8)
print(b.size())
# b x s x feats
idxs = torch.nonzero(a == 1)[:, 1]
print(idxs.size())
# b
c = b[torch.arange(b.size(0)), idxs]
print(c.size())
# b x feats
您可以简单地将它们切片并将其作为索引传递,如下所示:
In [193]: idxs = torch.nonzero(a == 1)
In [194]: c = b[idxs[:, 0], idxs[:, 1]]
In [195]: c
Out[195]:
tensor([[0.3411, 0.3944, 0.8108, 0.3986, 0.3917, 0.1176, 0.6252, 0.4885],
[0.5698, 0.3140, 0.6525, 0.7724, 0.3751, 0.3376, 0.5425, 0.1062],
[0.7780, 0.4572, 0.5645, 0.5759, 0.5957, 0.2750, 0.6429, 0.1029]])
或者,一种更简单且我更喜欢的方法是只使用 torch.where()
然后直接索引到张量 b
中,如:
In [196]: b[torch.where(a == 1)]
Out[196]:
tensor([[0.3411, 0.3944, 0.8108, 0.3986, 0.3917, 0.1176, 0.6252, 0.4885],
[0.5698, 0.3140, 0.6525, 0.7724, 0.3751, 0.3376, 0.5425, 0.1062],
[0.7780, 0.4572, 0.5645, 0.5759, 0.5957, 0.2750, 0.6429, 0.1029]])
关于上述使用 torch.where()
的方法的更多解释:它基于 advanced indexing 的概念工作。也就是说,当我们使用序列对象的元组对张量进行索引时,例如张量元组、列表元组、元组元组等
# some input tensor
In [207]: a
Out[207]:
tensor([[12., 1., 0., 0.],
[ 4., 9., 21., 1.],
[10., 2., 1., 0.]])
对于基本切片,我们需要一个整数索引元组:
In [212]: a[(1, 2)]
Out[212]: tensor(21.)
要使用高级索引实现相同的目的,我们需要一个序列对象元组:
# adv. indexing using a tuple of lists
In [213]: a[([1,], [2,])]
Out[213]: tensor([21.])
# adv. indexing using a tuple of tuples
In [215]: a[((1,), (2,))]
Out[215]: tensor([21.])
# adv. indexing using a tuple of tensors
In [214]: a[(torch.tensor([1,]), torch.tensor([2,]))]
Out[214]: tensor([21.])
并且返回张量的维度总是比输入张量的维度小一维。
作为@kmario23方案的补充,你仍然可以达到和
一样的效果
b[torch.nonzero(a==1,as_tuple=True)]
对于给定的二维张量,我想检索值为 1
的所有索引。我希望能够简单地使用 torch.nonzero(a == 1).squeeze()
,这将 return tensor([1, 3, 2])
。但是,torch.nonzero(a == 1)
return 是一个二维张量(没关系),每行有两个值(这不是我所期望的)。然后应该使用 returned 索引来索引 3D 张量的第二维(索引 1),再次 returning 2D 张量。
import torch
a = torch.Tensor([[12, 1, 0, 0],
[4, 9, 21, 1],
[10, 2, 1, 0]])
b = torch.rand(3, 4, 8)
print('a_size', a.size())
# a_size torch.Size([3, 4])
print('b_size', b.size())
# b_size torch.Size([3, 4, 8])
idxs = torch.nonzero(a == 1)
print('idxs_size', idxs.size())
# idxs_size torch.Size([3, 2])
print(b.gather(1, idxs))
显然,这不起作用,导致运行时错误:
RuntimeError: invalid argument 4: Index tensor must have same dimensions as input tensor at C:\w\s\windows\pytorch\aten\src\TH/generic/THTensorEvenMoreMath.cpp:453
看来idxs
并不是我想的那样,也不能按照我想的那样使用。 idxs
是
tensor([[0, 1],
[1, 3],
[2, 2]])
但是通读 documentation 我不明白为什么我还要取回结果张量中的行索引。现在,我知道我可以通过切片 idxs[:, 1]
获得正确的 idxs,但是我仍然不能将这些值用作 3D 张量的索引,因为会出现与之前相同的错误。是否可以将索引的一维张量用于 select 个给定维度的项目?
import torch
a = torch.Tensor([[12, 1, 0, 0],
[4, 9, 21, 1],
[10, 2, 1, 0]])
b = torch.rand(3, 4, 8)
print('a_size', a.size())
# a_size torch.Size([3, 4])
print('b_size', b.size())
# b_size torch.Size([3, 4, 8])
#idxs = torch.nonzero(a == 1, as_tuple=True)
idxs = torch.nonzero(a == 1)
#print('idxs_size', idxs.size())
print(torch.index_select(b,1,idxs[:,1]))
假设b
的三个维度为batch_size x sequence_length x features
(b x s x feats),可以实现如下预期结果
import torch
a = torch.Tensor([[12, 1, 0, 0],
[4, 9, 21, 1],
[10, 2, 1, 0]])
b = torch.rand(3, 4, 8)
print(b.size())
# b x s x feats
idxs = torch.nonzero(a == 1)[:, 1]
print(idxs.size())
# b
c = b[torch.arange(b.size(0)), idxs]
print(c.size())
# b x feats
您可以简单地将它们切片并将其作为索引传递,如下所示:
In [193]: idxs = torch.nonzero(a == 1)
In [194]: c = b[idxs[:, 0], idxs[:, 1]]
In [195]: c
Out[195]:
tensor([[0.3411, 0.3944, 0.8108, 0.3986, 0.3917, 0.1176, 0.6252, 0.4885],
[0.5698, 0.3140, 0.6525, 0.7724, 0.3751, 0.3376, 0.5425, 0.1062],
[0.7780, 0.4572, 0.5645, 0.5759, 0.5957, 0.2750, 0.6429, 0.1029]])
或者,一种更简单且我更喜欢的方法是只使用 torch.where()
然后直接索引到张量 b
中,如:
In [196]: b[torch.where(a == 1)]
Out[196]:
tensor([[0.3411, 0.3944, 0.8108, 0.3986, 0.3917, 0.1176, 0.6252, 0.4885],
[0.5698, 0.3140, 0.6525, 0.7724, 0.3751, 0.3376, 0.5425, 0.1062],
[0.7780, 0.4572, 0.5645, 0.5759, 0.5957, 0.2750, 0.6429, 0.1029]])
关于上述使用 torch.where()
的方法的更多解释:它基于 advanced indexing 的概念工作。也就是说,当我们使用序列对象的元组对张量进行索引时,例如张量元组、列表元组、元组元组等
# some input tensor
In [207]: a
Out[207]:
tensor([[12., 1., 0., 0.],
[ 4., 9., 21., 1.],
[10., 2., 1., 0.]])
对于基本切片,我们需要一个整数索引元组:
In [212]: a[(1, 2)]
Out[212]: tensor(21.)
要使用高级索引实现相同的目的,我们需要一个序列对象元组:
# adv. indexing using a tuple of lists
In [213]: a[([1,], [2,])]
Out[213]: tensor([21.])
# adv. indexing using a tuple of tuples
In [215]: a[((1,), (2,))]
Out[215]: tensor([21.])
# adv. indexing using a tuple of tensors
In [214]: a[(torch.tensor([1,]), torch.tensor([2,]))]
Out[214]: tensor([21.])
并且返回张量的维度总是比输入张量的维度小一维。
作为@kmario23方案的补充,你仍然可以达到和
一样的效果b[torch.nonzero(a==1,as_tuple=True)]