pytorch - torch.gather 的倒数
pytorch - reciprocal of torch.gather
给定一个输入张量 x
和一个索引张量 idxs
,我想检索 x
中索引不在 idxs
中的所有元素。也就是说,取 torch.gather
函数输出的相反值。
示例torch.gather
:
>>> x = torch.arange(30).reshape(3,10)
>>> idxs = torch.tensor([[1,2,3], [4,5,6], [7,8,9]], dtype=torch.long)
>>> torch.gather(x, 1, idxs)
tensor([[ 1, 2, 3],
[14, 15, 16],
[27, 28, 29]])
我确实想要达到的是
tensor([[ 0, 4, 5, 6, 7, 8, 9],
[10, 11, 12, 13, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26]])
什么是有效和高效的实施,可能使用 torch 实用程序?我不想使用任何 for 循环。
我假设 idxs
在其最深维度中只有 unique 元素。例如 idxs
将是调用 torch.topk
.
的结果
您可能希望构建形状为 (x.size(0), x.size(1)-idxs.size(1))
的张量(此处为 (3, 7)
)。这将对应于 idxs
的互补索引,关于 x
、 的形状,即 :
tensor([[0, 4, 5, 6, 7, 8, 9],
[0, 1, 2, 3, 7, 8, 9],
[0, 1, 2, 3, 4, 5, 6]])
我建议首先构建一个形状像 x
的张量,它可以揭示我们想要保留的位置和我们想要丢弃的位置,一种掩码。这可以使用 torch.scatter
来完成。这基本上将 0
s 分散在所需位置,即 m[i, idxs[i][j]] = 0
:
>>> m = torch.ones_like(x).scatter(1, idxs, 0)
tensor([[1, 0, 0, 0, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 0, 0, 0, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])
然后抓取非零(idxs
的互补部分)。 Select axis=1
上的第二个索引,并根据目标张量进行整形:
>>> idxs_ = m.nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))
tensor([[0, 4, 5, 6, 7, 8, 9],
[0, 1, 2, 3, 7, 8, 9],
[0, 1, 2, 3, 4, 5, 6]])
现在你知道该怎么做了吧?与您给出的 torch.gather
示例相同,但这次使用 idxs_
:
>>> torch.gather(x, 1, idxs_)
tensor([[ 0, 4, 5, 6, 7, 8, 9],
[10, 11, 12, 13, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26]])
总结:
>>> idxs_ = torch.ones_like(x).scatter(1, idxs, 0) \
.nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))
>>> torch.gather(x, 1, idxs_)
给定一个输入张量 x
和一个索引张量 idxs
,我想检索 x
中索引不在 idxs
中的所有元素。也就是说,取 torch.gather
函数输出的相反值。
示例torch.gather
:
>>> x = torch.arange(30).reshape(3,10)
>>> idxs = torch.tensor([[1,2,3], [4,5,6], [7,8,9]], dtype=torch.long)
>>> torch.gather(x, 1, idxs)
tensor([[ 1, 2, 3],
[14, 15, 16],
[27, 28, 29]])
我确实想要达到的是
tensor([[ 0, 4, 5, 6, 7, 8, 9],
[10, 11, 12, 13, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26]])
什么是有效和高效的实施,可能使用 torch 实用程序?我不想使用任何 for 循环。
我假设 idxs
在其最深维度中只有 unique 元素。例如 idxs
将是调用 torch.topk
.
您可能希望构建形状为 (x.size(0), x.size(1)-idxs.size(1))
的张量(此处为 (3, 7)
)。这将对应于 idxs
的互补索引,关于 x
、 的形状,即 :
tensor([[0, 4, 5, 6, 7, 8, 9],
[0, 1, 2, 3, 7, 8, 9],
[0, 1, 2, 3, 4, 5, 6]])
我建议首先构建一个形状像 x
的张量,它可以揭示我们想要保留的位置和我们想要丢弃的位置,一种掩码。这可以使用 torch.scatter
来完成。这基本上将 0
s 分散在所需位置,即 m[i, idxs[i][j]] = 0
:
>>> m = torch.ones_like(x).scatter(1, idxs, 0)
tensor([[1, 0, 0, 0, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 0, 0, 0, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])
然后抓取非零(idxs
的互补部分)。 Select axis=1
上的第二个索引,并根据目标张量进行整形:
>>> idxs_ = m.nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))
tensor([[0, 4, 5, 6, 7, 8, 9],
[0, 1, 2, 3, 7, 8, 9],
[0, 1, 2, 3, 4, 5, 6]])
现在你知道该怎么做了吧?与您给出的 torch.gather
示例相同,但这次使用 idxs_
:
>>> torch.gather(x, 1, idxs_)
tensor([[ 0, 4, 5, 6, 7, 8, 9],
[10, 11, 12, 13, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26]])
总结:
>>> idxs_ = torch.ones_like(x).scatter(1, idxs, 0) \
.nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))
>>> torch.gather(x, 1, idxs_)