通过匹配从张量中删除行

Remove rows from tensor by matching

我正在尝试做一些操作,比如 pytorch 中是否有张量

a = torch.tensor([[1,0]
                  ,[0,1]
                  ,[2,0]
                  ,[3,2]])


b = torch.tensor([[0,1]
                  ,[2,0]])

我想从 a.

中删除行 [0,1], [2,0],它们是 b 的行

有什么办法吗?

# result
a = torch.tensor([[1,0]
                  ,[3,2]])

如果张量形状是可广播的,你可以做到这一点。

对于形状 (?, d) 的张量 a 和形状 (d,) 的张量 b,你可以这样写:

cmp = a.eq(b).all(dim=1).logical_not(),即将 a 的每个 d 维行与 b 进行比较,并给出比较为 False 的索引。

从这些你可以很容易地像这样你的新张量: a = a[cmp]

b 本身包含批次维度时,我怀疑您会找到一种优雅的方式来执行此操作;你最好的选择是写一个 for 循环。

完整示例:

>>> xs = torch.tensor([[1,0], [0,1], [2,0], [3,2]])
>>> ys = torch.tensor([[0,1],[2,0]])
>>> for y in ys:
...     xs = xs[xs.eq(y).all(dim=1).logical_not()]
>>> xs
tensor([[1, 0],
        [3, 2]])


你可以利用广播做这样的事情:

import torch
a = torch.tensor([[1, 0], [0, 1], [2, 0], [3, 2]])
b = torch.tensor([[0, 1], [2, 0]])

indices = ((a == b[:, None]).sum(axis = 2) != a.shape[1]).all(axis = 0)
print(indices)
print(a[indices])

指数=

tensor([ True, False, False,  True])

a[指数] =

tensor([[1, 0],
        [3, 2]])

适用于形状为 m x np x n 的所有张量 ab,即数字列数 (a.shape[1]) 必须相同,您可以在任何列之间进行比较。行数。