获取张量 a 中存在于张量 b 中的元素的索引

Get indices of elements in tensor a that are present in tensor b

例如,我想获取张量 a 中值为 0 和 2 的元素的索引。这些值(0 和 2)存储在张量 b 中。我已经设计了一种 pythonic 的方式来做到这一点(如下所示),但我认为列表推导式并未针对 GPU 上的 运行 进行优化,或者可能有更多我不知道的 PyTorchy 方式来做到这一点。

import torch
a = torch.tensor([0, 1, 0, 1, 1, 0, 2])
b = torch.tensor([0, 2])
torch.tensor([x in b for x in a]).nonzero()

>>>> tensor([[0],
             [2],
             [5],
             [6]])

还有其他建议吗?或者这是可以接受的方式吗?

这是一种更有效的方法(如 jodag 在评论中发布的 link 所建议...):

(a[..., None] == b).any(-1).nonzero()