在 pytorch 中,是否有内置方法来提取具有给定索引的行?

In pytorch, is there a built-in method to extract rows with given indexes?

假设我有一个火炬张量

import torch
a = torch.tensor([[1,2,3],
                  [4,5,6],
                  [7,8,9]])

和一个列表

b = [0,2]

是否有内置方法提取第 0 行和第 2 行并将它们放入新的张量中:

tensor([[1,2,3],
        [7,8,9]])

特别是有没有这样的函数:

extract_rows(a,b) -> c

其中 c 包含所需的行。当然,这可以通过 for 循环完成,但内置方法通常更快。

请注意,该示例只是示例,列表中可能有数十个索引,张量中可能有数百行。

看看火炬内置 index_select() 方法。这会对你有所帮助。 要么 您可以使用切片来做到这一点。

tensor = [[1,2,3],
            [4,5,6],
            [7,8,9]]

new_tensor = tensor[0::2]
print(new_tensor)

输出:

[[1, 2, 3], [7, 8, 9]]

只需a[b]就可以了

import torch
a = torch.tensor([[1,2,3],
                  [4,5,6],
                  [7,8,9]])
b = [0,2]
a[b]
tensor([[1, 2, 3],
        [7, 8, 9]])