获取张量 a 中存在于张量 b 中的元素的索引
Get indices of elements in tensor a that are present in tensor b
例如,我想获取张量 a
中值为 0 和 2 的元素的索引。这些值(0 和 2)存储在张量 b
中。我已经设计了一种 pythonic 的方式来做到这一点(如下所示),但我认为列表推导式并未针对 GPU 上的 运行 进行优化,或者可能有更多我不知道的 PyTorchy 方式来做到这一点。
import torch
a = torch.tensor([0, 1, 0, 1, 1, 0, 2])
b = torch.tensor([0, 2])
torch.tensor([x in b for x in a]).nonzero()
>>>> tensor([[0],
[2],
[5],
[6]])
还有其他建议吗?或者这是可以接受的方式吗?
这是一种更有效的方法(如 jodag 在评论中发布的 link 所建议...):
(a[..., None] == b).any(-1).nonzero()
例如,我想获取张量 a
中值为 0 和 2 的元素的索引。这些值(0 和 2)存储在张量 b
中。我已经设计了一种 pythonic 的方式来做到这一点(如下所示),但我认为列表推导式并未针对 GPU 上的 运行 进行优化,或者可能有更多我不知道的 PyTorchy 方式来做到这一点。
import torch
a = torch.tensor([0, 1, 0, 1, 1, 0, 2])
b = torch.tensor([0, 2])
torch.tensor([x in b for x in a]).nonzero()
>>>> tensor([[0],
[2],
[5],
[6]])
还有其他建议吗?或者这是可以接受的方式吗?
这是一种更有效的方法(如 jodag 在评论中发布的 link 所建议...):
(a[..., None] == b).any(-1).nonzero()