如何以 GPU 友好的方式获取二维张量中多个元素的索引?
How to get indices of multiple elements in a 2D tensor, in a GPU friendly way?
此问题与已回答的问题类似 ,但该问题并未解决如何检索多个元素的索引。
我有一个二维张量 points
,它有很多行和少量列,我想得到一个包含该张量中所有元素的行索引的张量。我事先知道 points
中有哪些元素;它包含从 0 到 999 的整数元素,我可以使用范围函数制作一个张量来反映可能元素的集合。元素可以在任何列中。
如何以避免循环或使用 numpy 的方式检索每个元素出现在我的张量中的行索引,以便我可以在 GPU 上快速执行此操作?
我正在寻找类似 (points == elements).nonzero()[:,1]
的内容
谢谢!
尝试torch.cat([(t == i).nonzero() for i in elements_to_compare])
>>> import torch
>>> t = torch.empty((15,4)).random_(0, 999)
>>> t
tensor([[429., 833., 393., 828.],
[555., 893., 846., 909.],
[ 11., 861., 586., 222.],
[232., 92., 576., 452.],
[171., 341., 851., 953.],
[ 94., 46., 130., 413.],
[243., 251., 545., 331.],
[620., 29., 194., 176.],
[303., 905., 771., 149.],
[482., 225., 7., 315.],
[ 44., 547., 206., 299.],
[695., 7., 645., 385.],
[225., 898., 677., 693.],
[746., 21., 505., 875.],
[591., 254., 84., 888.]])
>>> torch.cat([(t == i).nonzero() for i in [7,385]])
tensor([[ 9, 2],
[11, 1],
[11, 3]])
>>> torch.cat([(t == i).nonzero()[:,1] for i in [7,385]])
tensor([2, 1, 3])
Numpy:
>>> np.nonzero(np.isin(t, [7,385]))
(array([ 9, 11, 11], dtype=int64), array([2, 1, 3], dtype=int64))
>>> np.nonzero(np.isin(t, [7,385]))[1]
array([2, 1, 3], dtype=int64)
我不确定我是否正确理解了您要查找的内容,但如果您想要某个值的索引,您可以尝试使用 where
和结果的稀疏表示。
例如在下面的张量 points
中,值 998
出现在索引 [0,0]
和 [2,0]
处。要获得这些指数,可以:
In [34]: points=torch.tensor([ [998, 6], [1, 3], [998, 999], [2, 3] ] )
In [35]: torch.where(points==998, points, torch.tensor(0)).to_sparse().indices()
Out[35]:
tensor([[0, 2],
[0, 0]])
此问题与已回答的问题类似
我有一个二维张量 points
,它有很多行和少量列,我想得到一个包含该张量中所有元素的行索引的张量。我事先知道 points
中有哪些元素;它包含从 0 到 999 的整数元素,我可以使用范围函数制作一个张量来反映可能元素的集合。元素可以在任何列中。
如何以避免循环或使用 numpy 的方式检索每个元素出现在我的张量中的行索引,以便我可以在 GPU 上快速执行此操作?
我正在寻找类似 (points == elements).nonzero()[:,1]
谢谢!
尝试torch.cat([(t == i).nonzero() for i in elements_to_compare])
>>> import torch
>>> t = torch.empty((15,4)).random_(0, 999)
>>> t
tensor([[429., 833., 393., 828.],
[555., 893., 846., 909.],
[ 11., 861., 586., 222.],
[232., 92., 576., 452.],
[171., 341., 851., 953.],
[ 94., 46., 130., 413.],
[243., 251., 545., 331.],
[620., 29., 194., 176.],
[303., 905., 771., 149.],
[482., 225., 7., 315.],
[ 44., 547., 206., 299.],
[695., 7., 645., 385.],
[225., 898., 677., 693.],
[746., 21., 505., 875.],
[591., 254., 84., 888.]])
>>> torch.cat([(t == i).nonzero() for i in [7,385]])
tensor([[ 9, 2],
[11, 1],
[11, 3]])
>>> torch.cat([(t == i).nonzero()[:,1] for i in [7,385]])
tensor([2, 1, 3])
Numpy:
>>> np.nonzero(np.isin(t, [7,385]))
(array([ 9, 11, 11], dtype=int64), array([2, 1, 3], dtype=int64))
>>> np.nonzero(np.isin(t, [7,385]))[1]
array([2, 1, 3], dtype=int64)
我不确定我是否正确理解了您要查找的内容,但如果您想要某个值的索引,您可以尝试使用 where
和结果的稀疏表示。
例如在下面的张量 points
中,值 998
出现在索引 [0,0]
和 [2,0]
处。要获得这些指数,可以:
In [34]: points=torch.tensor([ [998, 6], [1, 3], [998, 999], [2, 3] ] )
In [35]: torch.where(points==998, points, torch.tensor(0)).to_sparse().indices()
Out[35]:
tensor([[0, 2],
[0, 0]])