在 PyTorch 中索引多维张量中的最大元素

Indexing the max elements in a multidimensional tensor in PyTorch

我正在尝试对多维张量中最后一个维度的最大元素进行索引。例如,假设我有一个张量

A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)

这里 idx 存储了最大的索引,可能看起来像

>>>> A
tensor([[[ 1.0503,  0.4448,  1.8663],
     [ 0.8627,  0.0685,  1.4241]],

    [[ 1.2924,  0.2456,  0.1764],
     [ 1.3777,  0.9401,  1.4637]],

    [[ 0.5235,  0.4550,  0.2476],
     [ 0.7823,  0.3004,  0.7792]],

    [[ 1.9384,  0.3291,  0.7914],
     [ 0.5211,  0.1320,  0.6330]],

    [[ 0.3292,  0.9086,  0.0078],
     [ 1.3612,  0.0610,  0.4023]]])
>>>> idx
tensor([[ 2,  2],
    [ 0,  2],
    [ 0,  0],
    [ 0,  2],
    [ 1,  0]])

我希望能够访问这些索引并根据它们分配给另一个张量。意思是我希望能够做到

B = torch.new_zeros(A.size())
B[idx] = A[idx]

其中 B 在所有地方都为 0,除了 A 在最后一个维度上最大的地方。即B应该存储

>>>>B
tensor([[[ 0,  0,  1.8663],
     [ 0,  0,  1.4241]],

    [[ 1.2924,  0,  0],
     [ 0,  0,  1.4637]],

    [[ 0.5235,  0,  0],
     [ 0.7823,  0,  0]],

    [[ 1.9384,  0,  0],
     [ 0,  0,  0.6330]],

    [[ 0,  0.9086,  0],
     [ 1.3612,  0,  0]]])

事实证明这比我预期的要困难得多,因为 idx 没有正确索引数组 A。到目前为止,我一直无法找到使用 idx 来索引 A 的矢量化解决方案。

有没有好的矢量化方法来做到这一点?

一个丑陋的变通方法是从 idx 中创建一个二进制掩码并用它来索引数组。基本代码如下所示:

import torch
torch.manual_seed(0)

A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)

mask = torch.arange(A.size(2)).reshape(1, 1, -1) == idx.unsqueeze(2)
B = torch.zeros_like(A)
B[mask] = A[mask]
print(A)
print(B)

诀窍是 torch.arange(A.size(2)) 枚举 idx 中的可能值,而 mask 在它们等于 idx 的地方是非零的。备注:

  1. 如果你真的舍弃了torch.max的第一个输出,你可以用torch.argmax代替。
  2. 我假设这是一些更广泛问题的最小示例,但请注意,您目前正在使用大小为 (1, 1, 3).
  3. 的内核重新发明 torch.nn.functional.max_pool3d
  4. 此外,请注意,使用掩码赋值对张量进行就地修改可能会导致 autograd 出现问题,因此您可能需要使用 torch.where,如 所示。

我希望有人想出一个更简洁的解决方案(避免 mask 数组的中间分配),可能会使用 torch.index_select,但我无法让它工作现在。

您可以使用torch.meshgrid创建索引元组:

>>> index_tuple = torch.meshgrid([torch.arange(x) for x in A.size()[:-1]]) + (idx,)
>>> B = torch.zeros_like(A)
>>> B[index_tuple] = A[index_tuple]

请注意,您还可以通过(对于 3D 的特定情况)模仿 meshgrid

>>> index_tuple = (
...     torch.arange(A.size(0))[:, None],
...     torch.arange(A.size(1))[None, :],
...     idx
... )

再解释一下:
我们会有这样的索引:

In [173]: idx 
Out[173]: 
tensor([[2, 1],
        [2, 0],
        [2, 1],
        [2, 2],
        [2, 2]])

据此,我们想要转到三个索引(因为我们的张量是 3D,我们需要三个数字来检索每个元素)。基本上我们想在前两个维度构建一个网格,如下所示。 (这就是我们使用 meshgrid 的原因)。

In [174]: A[0, 0, 2], A[0, 1, 1]  
Out[174]: (tensor(0.6288), tensor(-0.3070))

In [175]: A[1, 0, 2], A[1, 1, 0]  
Out[175]: (tensor(1.7085), tensor(0.7818))

In [176]: A[2, 0, 2], A[2, 1, 1]  
Out[176]: (tensor(0.4823), tensor(1.1199))

In [177]: A[3, 0, 2], A[3, 1, 2]    
Out[177]: (tensor(1.6903), tensor(1.0800))

In [178]: A[4, 0, 2], A[4, 1, 2]          
Out[178]: (tensor(0.9138), tensor(0.1779))

在上面的5行中,索引中的前两个数字基本上是我们使用meshgrid构建的网格,第三个数字来自idx

即前两个数字组成一个网格。

 (0, 0) (0, 1)
 (1, 0) (1, 1)
 (2, 0) (2, 1)
 (3, 0) (3, 1)
 (4, 0) (4, 1)

可以使用 torch.scatter here

>>> import torch
>>> a = torch.randn(4,2,3)
>>> a
tensor([[[ 0.1583,  0.1102, -0.8188],
         [ 0.6328, -1.9169, -0.5596]],

        [[ 0.5335,  0.4069,  0.8403],
         [-1.2537,  0.9868, -0.4947]],

        [[-1.2830,  0.4386, -0.0107],
         [ 1.3384,  0.5651,  0.2877]],

        [[-0.0334, -1.0619, -0.1144],
         [ 0.1954, -0.7371,  1.7001]]])
>>> ind = torch.max(a,1,keepdims=True)[1]
>>> ind
tensor([[[1, 0, 1]],

        [[0, 1, 0]],

        [[1, 1, 1]],

        [[1, 1, 1]]])
>>> torch.zeros_like(a).scatter(1,ind,a)
tensor([[[ 0.0000,  0.1102,  0.0000],
         [ 0.1583,  0.0000, -0.8188]],

        [[ 0.5335,  0.0000,  0.8403],
         [ 0.0000,  0.4069,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000],
         [-1.2830,  0.4386, -0.0107]],

        [[ 0.0000,  0.0000,  0.0000],
         [-0.0334, -1.0619, -0.1144]]])