如何仅在 Pytorch 张量的选定索引中找到 argmax/argmin

How to find argmax/argmin in only selected indices of a Pytorch tensor

我有一个距离张量

tensor([ 5,  10,  2,  3,  4], device='cuda:0')

还有一个指数张量

tensor([ 0,  2,  3], device='cuda:0')

我想找到距离张量的argmax,但只在索引张量指定的索引的子集

在这个例子中,我将查看距离张量的第 0、2、3 个元素(值 5、2、3)并返回索引 0(最大值 - 5 在距离中的第 0 位张量)

tensor([ 0], device='cuda:0')

如果不使用 for 循环,这样的事情是否可行? 谢谢

举个例子。您可以检查所选项目子集的最大 dist 值是否位于索引零处,并且最终输出张量也包含值零。请注意,当我们使用一维张量时,torch.index_select 中的 dim 参数为零。

import torch

dist = torch.randn(5, 1)
#tensor([[ 0.3392],
#        [ 0.4472],
#        [ 0.1398],
#        [-1.0379],
#        [ 0.2950]])


idx = torch.tensor([0,2,3])
#tensor([0, 2, 3])

仅使用 max 函数和张量过滤:

max_val = torch.max(torch.index_select(dist, 0, idx)).item()
#0.33918169140815735
(dist == max_val).nonzero(as_tuple=True)[0]
#tensor([0])