尺寸 M < 32 的 Pytorch 张量索引错误?

Pytorch tensor indexing error for sizes M < 32?

我正在尝试通过索引矩阵访问 pytorch 张量,我最近发现这段代码我找不到它不起作用的原因。

下面的代码分为两部分。前半部分证明有效,而后半部分出现错误。我看不出原因。有人能解释一下吗?

import torch
import numpy as np

a = torch.rand(32, 16)
m, n = a.shape
xx, yy = np.meshgrid(np.arange(m), np.arange(m))
result = a[xx]   # WORKS for a torch.tensor of size M >= 32. It doesn't work otherwise.

a = torch.rand(16, 16)
m, n = a.shape
xx, yy = np.meshgrid(np.arange(m), np.arange(m))
result = a[xx]   # IndexError: too many indices for tensor of dimension 2

如果我更改 a = np.random.rand(16, 16) 它也能正常工作。

首先,让我快速介绍一下使用 numpy 数组和另一个张量对张量进行索引的想法。

示例:这是我们要索引的目标张量

    numpy_indices = torch.tensor([[0, 1, 2, 7],
                                  [0, 1, 2, 3]])   # numpy array

    tensor_indices = torch.tensor([[0, 1, 2, 7],
                                   [0, 1, 2, 3]])   # 2D tensor

    t = torch.tensor([[1,  2,  3,   4],            # targeted tensor
                      [5,  6,  7,   8],
                      [9,  10, 11, 12],
                      [13, 14, 15, 16],
                      [17, 18, 19, 20],
                      [21, 22, 23, 24],
                      [25, 26, 27, 28],
                      [29, 30, 31, 32]])
     numpy_result = t[numpy_indices]
     tensor_result = t[tensor_indices]
  • 使用 2D numpy 数组进行索引:索引像对 (x,y) 张量 [行,列] 一样读取,例如t[0,0], t[1,1], t[2,2], and t[7,3].

    print(numpy_result)  # tensor([ 1,  6, 11, 32])
    
  • 使用二维张量进行索引:以 row-wise 方式遍历索引张量,每个值都是目标张量中一行的索引。 例如[ [t[0],t[1],t[2],[7]] , [[0],[1],[2],[3]] ]见下例,索引后tensor_result的新形状为(tensor_indices.shape[0],tensor_indices.shape[1],t.shape[1])=(2,4,4).

    print(tensor_result)     # tensor([[[ 1,  2,  3,  4],
                             #          [ 5,  6,  7,  8],
                             #          [ 9, 10, 11, 12],
                             #          [29, 30, 31, 32]],
    
                             #          [[ 1,  2,  3,  4],
                             #           [ 5,  6,  7,  8],
                             #           [ 9, 10, 11, 12],
                             #           [ 13, 14, 15, 16]]])
    

如果您尝试在 numpy_indices 中添加第三行,您将得到相同的错误,因为索引将由 3D 表示,例如 (0,0,0)...(7 ,3,3).

indices = np.array([[0, 1, 2, 7],
                    [0, 1, 2, 3],
                    [0, 1, 2, 3]])

print(numpy_result)   # IndexError: too many indices for tensor of dimension 2

但是张量索引不是这样,形状会更大(3,4,4)。

最后,如您所见,两种索引类型的输出完全不同。要解决您的问题,您可以使用

xx = torch.tensor(xx).long()  # convert a numpy array to a tensor

在高级索引的情况下(numpy_indices > 3 的行)会发生什么,因为您的情况仍然不明确且未解决,您可以检查 1 , 2, 3.

对于前来寻找答案的人:它看起来像是 pyTorch 中的一个错误。

使用 numpy 数组进行索引的定义不明确,如果使用张量对张量进行索引,它有效。所以,在我的示例代码中,这完美地工作:

a = torch.rand(M, N)
m, n = a.shape
xx, yy = torch.meshgrid(torch.arange(m), torch.arange(m), indexing='xy')
result = a[xx]   # WORKS 

我做了一个gist to check it, and it's available here