使用索引索引 Tensor 的第二维

Indexing second dimension of Tensor using indices

我 select 使用索引张量编辑了张量中的元素。下面的代码我使用索引列表 0, 3, 2, 1 到 select 11, 15, 2, 5

>>> import torch
>>> a = torch.Tensor([5,2,11, 15])
>>> torch.randperm(4)

 0
 3
 2
 1
[torch.LongTensor of size 4]

>>> i = torch.randperm(4)
>>> a[i]

 11
 15
  2
  5
[torch.FloatTensor of size 4]

现在,我有

>>> b = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
>>> b

  5   2  11  15
  5   2  11  15
  5   2  11  15
[torch.FloatTensor of size 3x4]

现在,我想对 select 列 0、3、2、1 使用索引。换句话说,我想要这样的张量

>>> b

 11  15   2   5
 11  15   2   5
 11  15   2   5
[torch.FloatTensor of size 3x4]

如果使用pytorch v0.1.12版本

对于这个版本,没有简单的方法可以做到这一点。尽管 pytorch 承诺张量操作与 numpy 完全一样,但仍然缺乏一些功能。这是其中之一。

如果您使用的是 numpy 数组,通常可以相对轻松地完成此操作。像这样。

>>> i = [2, 1, 0, 3]
>>> a = np.array([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
>>> a[:, i]

array([[11,  2,  5, 15],
       [11,  2,  5, 15],
       [11,  2,  5, 15]])

但是张量同样的事情会给你一个错误:

>>> i = torch.LongTensor([2, 1, 0, 3])
>>> a = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
>>> a[:,i]

错误:

TypeError: indexing a tensor with an object of type torch.LongTensor. The only supported types are integers, slices, numpy scalars and torch.LongTensor or torch.ByteTensor as the only argument.

TypeError 告诉你的是,如果你计划使用 LongTensor 或 ByteTensor 进行索引,那么唯一有效的语法是 a[<LongTensor>]a[<ByteTensor>]。除此以外的任何东西都不起作用。

由于此限制,您有两个选择:

选项 1: 转换为 numpy,置换,然后返回 Tensor

>>> i = [2, 1, 0, 3]
>>> a = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
>>> np_a = a.numpy()
>>> np_a = np_a[:,i]
>>> a = torch.from_numpy(np_a)
>>> a

 11   2   5  15
 11   2   5  15
 11   2   5  15
[torch.FloatTensor of size 3x4]

选项2:将你想要置换的dim移动到0然后再做

您将要排列的暗淡(在您的情况下为暗淡=1)移动到 0,执行排列,然后将其移回。它有点 hacky,但它完成了工作。

def hacky_permute(a, i, dim):
    a = torch.transpose(a, 0, dim)
    a = a[i]
    a = torch.transpose(a, 0, dim)
    return a

并像这样使用它:

>>> i = torch.LongTensor([2, 1, 0, 3])
>>> a = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
>>> a = hacky_permute(a, i, dim=1)
>>> a

 11   2   5  15
 11   2   5  15
 11   2   5  15
[torch.FloatTensor of size 3x4]

如果使用pytorch v0.2.0版本

使用张量的直接索引现在可以在此版本中使用。即

>>> i = torch.LongTensor([2, 1, 0, 3])
>>> a = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
>>> a[:,i]

 11   2   5  15
 11   2   5  15
 11   2   5  15
[torch.FloatTensor of size 3x4]