如何以 GPU 友好的方式获取二维张量中多个元素的索引?

How to get indices of multiple elements in a 2D tensor, in a GPU friendly way?

此问题与已回答的问题类似 ,但该问题并未解决如何检索多个元素的索引。

我有一个二维张量 points,它有很多行和少量列,我想得到一个包含该张量中所有元素的行索引的张量。我事先知道 points 中有哪些元素;它包含从 0 到 999 的整数元素,我可以使用范围函数制作一个张量来反映可能元素的集合。元素可以在任何列中。

如何以避免循环或使用 numpy 的方式检索每个元素出现在我的张量中的行索引,以便我可以在 GPU 上快速执行此操作?

我正在寻找类似 (points == elements).nonzero()[:,1]

的内容

谢谢!

尝试torch.cat([(t == i).nonzero() for i in elements_to_compare])

>>> import torch
>>> t = torch.empty((15,4)).random_(0, 999)
>>> t
tensor([[429., 833., 393., 828.],
        [555., 893., 846., 909.],
        [ 11., 861., 586., 222.],
        [232.,  92., 576., 452.],
        [171., 341., 851., 953.],
        [ 94.,  46., 130., 413.],
        [243., 251., 545., 331.],
        [620.,  29., 194., 176.],
        [303., 905., 771., 149.],
        [482., 225.,   7., 315.],
        [ 44., 547., 206., 299.],
        [695.,   7., 645., 385.],
        [225., 898., 677., 693.],
        [746.,  21., 505., 875.],
        [591., 254.,  84., 888.]])
>>> torch.cat([(t == i).nonzero() for i in [7,385]])
tensor([[ 9,  2],
        [11,  1],
        [11,  3]])

>>> torch.cat([(t == i).nonzero()[:,1] for i in [7,385]])
tensor([2, 1, 3])

Numpy:

>>> np.nonzero(np.isin(t, [7,385]))
(array([ 9, 11, 11], dtype=int64), array([2, 1, 3], dtype=int64))

>>> np.nonzero(np.isin(t, [7,385]))[1]
array([2, 1, 3], dtype=int64)

我不确定我是否正确理解了您要查找的内容,但如果您想要某个值的索引,您可以尝试使用 where 和结果的稀疏表示。

例如在下面的张量 points 中,值 998 出现在索引 [0,0][2,0] 处。要获得这些指数,可以:

In [34]: points=torch.tensor([ [998,  6], [1, 3], [998, 999], [2, 3] ] )

In [35]: torch.where(points==998, points, torch.tensor(0)).to_sparse().indices()
Out[35]:
tensor([[0, 2],
        [0, 0]])