Select 基于索引的维度上的张量切片

Select tensor slice along a dimension based on index

我有一个形状如下的 PyTorch 张量:(100, 5, 100)。我需要通过从第二维的每一行中仅选择一个项目,将其转换为形状为 (100, 100) 的张量,这意味着在这 5 个元素中,我只需要一个及其对应的 100 个元素。

为了执行此操作,我有第二个形状为 (100,) 的张量,其索引指定应在每一行中选择这 5 个项目中的哪一个。

有没有一种简单的方法可以执行此选择而不必过多地弄乱尺寸?

假设张量的索引名为 idx,形状为 (100,)。值为 source 的张量。然后到 select:

result = source[torch.arange(100), idx]