使用没有循环的坐标张量对 pytorch 张量进行切片
Slice pytorch tensor using coordinates tensor without loop
我有一个维度为 (d1 x d2 x d3 x ... dk
) 的张量 T
和一个维度为 (p x q
) 的张量 I
。这里,I
包含T
的坐标,但是q < k
,I
的每一列对应T
的一个维度。我有另一个张量 V
,维度 p x di x ...dj
,其中 sum([di, ..., dj]) = k - q
。 (di, .., dj
) 对应于 I
中缺失的维度。我需要执行 T[I] = V
使用 numpy
数组的此类问题的具体示例已发布 [1].
依赖numpy.index_exp
的[2] uses fancy indexing[3]。在 pytorch
的情况下,此类选项不可用。有没有其他方法可以在 pytorch
中模仿这个而不使用循环或将张量转换为 numpy
数组?
下面是演示:
import torch
t = torch.randn((32, 16, 60, 64)) # tensor
i0 = torch.randint(0, 32, (10, 1)).to(dtype=torch.long) # indexes for dim=0
i2 = torch.randint(0, 60, (10, 1)).to(dtype=torch.long) # indexes for dim=2
i = torch.cat((i0, i2), 1) # indexes
v = torch.randn((10, 16, 64)) # to be assigned
# t[i0, :, i2, :] = v ?? Obviously this does not work
[1]
[2]
[3] https://numpy.org/doc/stable/reference/generated/numpy.s_.html
经过评论区的讨论,我们得出以下解决方案:
import torch
t = torch.randn((32, 16, 60, 64)) # tensor
# indices
i0 = torch.randint(0, 32, (10,)).to(dtype=torch.long) # indexes for dim=0
i2 = torch.randint(0, 60, (10,)).to(dtype=torch.long) # indexes for dim=2
v = torch.randn((10, 16, 64)) # to be assigned
t[(i0, slice(None), i2, slice(None))] = v
我有一个维度为 (d1 x d2 x d3 x ... dk
) 的张量 T
和一个维度为 (p x q
) 的张量 I
。这里,I
包含T
的坐标,但是q < k
,I
的每一列对应T
的一个维度。我有另一个张量 V
,维度 p x di x ...dj
,其中 sum([di, ..., dj]) = k - q
。 (di, .., dj
) 对应于 I
中缺失的维度。我需要执行 T[I] = V
使用 numpy
数组的此类问题的具体示例已发布
依赖numpy.index_exp
的pytorch
的情况下,此类选项不可用。有没有其他方法可以在 pytorch
中模仿这个而不使用循环或将张量转换为 numpy
数组?
下面是演示:
import torch
t = torch.randn((32, 16, 60, 64)) # tensor
i0 = torch.randint(0, 32, (10, 1)).to(dtype=torch.long) # indexes for dim=0
i2 = torch.randint(0, 60, (10, 1)).to(dtype=torch.long) # indexes for dim=2
i = torch.cat((i0, i2), 1) # indexes
v = torch.randn((10, 16, 64)) # to be assigned
# t[i0, :, i2, :] = v ?? Obviously this does not work
[1]
[2]
[3] https://numpy.org/doc/stable/reference/generated/numpy.s_.html
经过评论区的讨论,我们得出以下解决方案:
import torch
t = torch.randn((32, 16, 60, 64)) # tensor
# indices
i0 = torch.randint(0, 32, (10,)).to(dtype=torch.long) # indexes for dim=0
i2 = torch.randint(0, 60, (10,)).to(dtype=torch.long) # indexes for dim=2
v = torch.randn((10, 16, 64)) # to be assigned
t[(i0, slice(None), i2, slice(None))] = v