基于另一个没有循环的张量过滤火炬张量

Filter torch tensor based on another tensor without loops

假设我有以下两个torch.Tensors:

x = torch.tensor([0,0,0,1,1,2,2,2,2], dtype=torch.int64)
y = torch.tensor([0,2], dtype=torch.int64)

我想以某种方式过滤 x,以便仅保留 y 中的值:

x_filtered = torch.tensor([0,0,0,2,2,2,2])

再比如,如果y = torch.tensor([0,1]),那么x_filtered = torch.tensor([0,0,0,1,1])x,y 始终是 1D 和 int64。 y 总是排序的,如果它使它更简单,我们可以假设 x 总是排序。

我想过各种不使用循环的方法,但都失败了。我不能真正使用循环,因为我的用例涉及数百万的 x 和数万的 y。感谢您的帮助。


刚刚意识到我需要的是相当于 numpy.in1d

的手电筒

答案是https://pytorch.org/docs/master/generated/torch.isin.html?highlight=isin#torch.isin:

>>> torch.isin(x,y)
tensor([ True,  True,  True, False, False,  True,  True,  True,  True])

要在任务中根据需要过滤张量,您需要使用 torch 中可用的 isin 函数。使用方法如下:-

import torch
x = torch.tensor([0,0,0,1,1,2,2,2,2,3], dtype=torch.int64)
y = torch.tensor([0,2], dtype=torch.int64)
# torch.isin(x, y)
c=x[torch.isin(x,y)]
print(c)

在运行此代码后,您将获得您的首选答案。