在 PyTorch 中用张量索引多维张量

Indexing a multi-dimensional tensor with a tensor in PyTorch

我有以下代码:

a = torch.randint(0,10,[3,3,3,3])
b = torch.LongTensor([1,1,1,1])

我有一个多维索引 b 并想将其用于 select a 中的单个单元格。如果 b 不是张量,我可以这样做:

a[1,1,1,1]

哪个 returns 是正确的单元格,但是:

a[b]

不起作用,因为它只是 selects a[1] 四次。

我该怎么做?谢谢

您可以使用 chunkb 拆分为 4 个,然后使用分块 b 来索引您想要的特定元素:

>> a = torch.arange(3*3*3*3).view(3,3,3,3)
>> b = torch.LongTensor([[1,1,1,1], [2,2,2,2], [0, 0, 0, 0]]).t()
>> a[b.chunk(chunks=4, dim=0)]   # here's the trick!
Out[24]: tensor([[40, 80,  0]])

它的好处是它可以很容易地推广到a的任何维度,你只需要让卡盘的数量等于a的维度。

一个更优雅(也更简单)的解决方案可能是简单地将 b 转换为一个元组:

a[tuple(b)]
Out[10]: tensor(5.)

我很想知道它是如何与 "regular" numpy 一起工作的,并找到了一篇解释得很好的相关文章 here