如何仅在 Pytorch 张量的选定索引中找到 argmax/argmin
How to find argmax/argmin in only selected indices of a Pytorch tensor
我有一个距离张量
tensor([ 5, 10, 2, 3, 4], device='cuda:0')
还有一个指数张量
tensor([ 0, 2, 3], device='cuda:0')
我想找到距离张量的argmax,但只在索引张量指定的索引的子集上。
在这个例子中,我将查看距离张量的第 0、2、3 个元素(值 5、2、3)并返回索引 0(最大值 - 5 在距离中的第 0 位张量)
tensor([ 0], device='cuda:0')
如果不使用 for 循环,这样的事情是否可行?
谢谢
举个例子。您可以检查所选项目子集的最大 dist
值是否位于索引零处,并且最终输出张量也包含值零。请注意,当我们使用一维张量时,torch.index_select
中的 dim
参数为零。
import torch
dist = torch.randn(5, 1)
#tensor([[ 0.3392],
# [ 0.4472],
# [ 0.1398],
# [-1.0379],
# [ 0.2950]])
idx = torch.tensor([0,2,3])
#tensor([0, 2, 3])
仅使用 max
函数和张量过滤:
max_val = torch.max(torch.index_select(dist, 0, idx)).item()
#0.33918169140815735
(dist == max_val).nonzero(as_tuple=True)[0]
#tensor([0])
我有一个距离张量
tensor([ 5, 10, 2, 3, 4], device='cuda:0')
还有一个指数张量
tensor([ 0, 2, 3], device='cuda:0')
我想找到距离张量的argmax,但只在索引张量指定的索引的子集上。
在这个例子中,我将查看距离张量的第 0、2、3 个元素(值 5、2、3)并返回索引 0(最大值 - 5 在距离中的第 0 位张量)
tensor([ 0], device='cuda:0')
如果不使用 for 循环,这样的事情是否可行? 谢谢
举个例子。您可以检查所选项目子集的最大 dist
值是否位于索引零处,并且最终输出张量也包含值零。请注意,当我们使用一维张量时,torch.index_select
中的 dim
参数为零。
import torch
dist = torch.randn(5, 1)
#tensor([[ 0.3392],
# [ 0.4472],
# [ 0.1398],
# [-1.0379],
# [ 0.2950]])
idx = torch.tensor([0,2,3])
#tensor([0, 2, 3])
仅使用 max
函数和张量过滤:
max_val = torch.max(torch.index_select(dist, 0, idx)).item()
#0.33918169140815735
(dist == max_val).nonzero(as_tuple=True)[0]
#tensor([0])