检查张量的每个元素是否包含在列表中
Check if each element of a tensor is contained in a list
假设我有一个张量 A
和一个值容器 vals
。是否有一种干净的方法可以返回与 A
形状相同的布尔张量,每个元素都是 A
的元素是否包含在 vals
中?例如:
A = torch.tensor([[1,2,3],
[4,5,6]])
vals = [1,5]
# Desired output
torch.tensor([[True,False,False],
[False,True,False]])
您可以使用 for 循环实现此目的:
sum(A==i for i in B).bool()
你可以简单地这样做:
result = A.apply_(lambda x: x in vals).bool()
然后result
将包含这个张量:
tensor([[ True, False, False],
[False, True, False]])
我在这里简单地使用了一个 lambda 函数和 apply_ 方法,您可以在 official documentation.
中找到它们
[list(map(lambda x: x in vals, thelist)) for thelist in A]
假设我有一个张量 A
和一个值容器 vals
。是否有一种干净的方法可以返回与 A
形状相同的布尔张量,每个元素都是 A
的元素是否包含在 vals
中?例如:
A = torch.tensor([[1,2,3],
[4,5,6]])
vals = [1,5]
# Desired output
torch.tensor([[True,False,False],
[False,True,False]])
您可以使用 for 循环实现此目的:
sum(A==i for i in B).bool()
你可以简单地这样做:
result = A.apply_(lambda x: x in vals).bool()
然后result
将包含这个张量:
tensor([[ True, False, False],
[False, True, False]])
我在这里简单地使用了一个 lambda 函数和 apply_ 方法,您可以在 official documentation.
中找到它们[list(map(lambda x: x in vals, thelist)) for thelist in A]