pytorch - torch.gather 的倒数

pytorch - reciprocal of torch.gather

给定一个输入张量 x 和一个索引张量 idxs,我想检索 x 中索引不在 idxs 中的所有元素。也就是说,取 torch.gather 函数输出的相反值。

示例torch.gather

>>> x = torch.arange(30).reshape(3,10)
>>> idxs = torch.tensor([[1,2,3], [4,5,6], [7,8,9]], dtype=torch.long)
>>> torch.gather(x, 1, idxs)
tensor([[ 1,  2,  3],
        [14, 15, 16],
        [27, 28, 29]])

我确实想要达到的是

tensor([[ 0,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26]])

什么是有效和高效的实施,可能使用 torch 实用程序?我不想使用任何 for 循环。

我假设 idxs 在其最深维度中只有 unique 元素。例如 idxs 将是调用 torch.topk.

的结果

您可能希望构建形状为 (x.size(0), x.size(1)-idxs.size(1)) 的张量(此处为 (3, 7))。这将对应于 idxs 的互补索引,关于 x 的形状,即 :

tensor([[0, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6]])

我建议首先构建一个形状像 x 的张量,它可以揭示我们想要保留的位置和我们想要丢弃的位置,一种掩码。这可以使用 torch.scatter 来完成。这基本上将 0s 分散在所需位置,即 m[i, idxs[i][j]] = 0:

>>> m = torch.ones_like(x).scatter(1, idxs, 0)
tensor([[1, 0, 0, 0, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])

然后抓取非零(idxs的互补部分)。 Select axis=1 上的第二个索引,并根据目标张量进行整形:

>>> idxs_ = m.nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))
tensor([[0, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6]])

现在你知道该怎么做了吧?与您给出的 torch.gather 示例相同,但这次使用 idxs_:

>>> torch.gather(x, 1, idxs_)
tensor([[ 0,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26]])

总结:

>>> idxs_ = torch.ones_like(x).scatter(1, idxs, 0) \
        .nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))

>>> torch.gather(x, 1, idxs_)