我可以提取与 pytorch 张量中某个键对应的所有索引吗?

Can I extract all indices that correspond to a certain key in a pytorch tensor?

假设我有一个 pytorch 张量 tensor([3,5,7,3,9,3,0])。我想提取 3 出现的索引,即 tensor([0,3,5])。有内置函数吗?

有专门的 function 用于此:

   torch.where(my_tensor == the_number)
t = torch.Tensor([1, 2, 3 , 2 , 5])
print ((t == 2).nonzero())

nonzero 打印火炬张量的所有非零位置 https://pytorch.org/docs/master/generated/torch.nonzero.html