使用没有循环的坐标张量对 pytorch 张量进行切片

Slice pytorch tensor using coordinates tensor without loop

我有一个维度为 (d1 x d2 x d3 x ... dk) 的张量 T 和一个维度为 (p x q) 的张量 I。这里,I包含T的坐标,但是q < kI的每一列对应T的一个维度。我有另一个张量 V,维度 p x di x ...dj,其中 sum([di, ..., dj]) = k - q。 (di, .., dj) 对应于 I 中缺失的维度。我需要执行 T[I] = V

使用 numpy 数组的此类问题的具体示例已发布 [1].

依赖numpy.index_exp[2] uses fancy indexing[3]。在 pytorch 的情况下,此类选项不可用。有没有其他方法可以在 pytorch 中模仿这个而不使用循环或将张量转换为 numpy 数组?

下面是演示:

import torch
t = torch.randn((32, 16, 60, 64)) # tensor

i0 = torch.randint(0, 32, (10, 1)).to(dtype=torch.long) # indexes for dim=0
i2 = torch.randint(0, 60, (10, 1)).to(dtype=torch.long) # indexes for dim=2

i = torch.cat((i0, i2), 1) # indexes
v = torch.randn((10, 16, 64)) # to be assigned

# t[i0, :, i2, :] = v ?? Obviously this does not work

[1]

[2]

[3] https://numpy.org/doc/stable/reference/generated/numpy.s_.html

经过评论区的讨论,我们得出以下解决方案:

import torch
t = torch.randn((32, 16, 60, 64)) # tensor

# indices
i0 = torch.randint(0, 32, (10,)).to(dtype=torch.long) # indexes for dim=0
i2 = torch.randint(0, 60, (10,)).to(dtype=torch.long) # indexes for dim=2

v = torch.randn((10, 16, 64)) # to be assigned

t[(i0, slice(None), i2, slice(None))] = v