切片 pytorch 张量和 data_ptr() 的使用

Slicing pytorch tensors and use of data_ptr()

a = tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  8, 10],
        [11, 12, 13, 14, 15]])

我有一个 torch 张量,我需要索引一个张量 c 这样 c = [[3], [8], [13]]

所以我做了 c = a[:,[2]] 这给了我预期的答案,但它在自动评分器上仍然失败。 autograder 使用如下检查功能 -

def check(orig, actual, expected):
  expected = torch.tensor(expected)
  same_elements = (actual == expected).all().item() == 1
  same_storage = (orig.storage().data_ptr() == actual.storage().data_ptr())
  return same_elements and same_storage

print('c correct:', check(a, c, [[3], [8], [13]]))

我试着调试了一下,发现same_storage是假的,我不明白为什么orig.storage().data_ptr() == actual.storage().data_ptr()应该是True,以及它有什么不同。

更新 我能够通过 c = a[:, 2:3] 而不是 c = a[:, [2]] 得到正确答案,有什么区别?

PyTorch 允许张量成为现有张量的“视图”,这样它与其基础张量共享相同的底层数据,从而避免显式数据复制,从而能够执行快速且内存高效的操作。

Tensor View docs

中所述

When accessing the contents of a tensor via indexing, PyTorch follows Numpy behaviors that basic indexing returns views, while advanced indexing returns a copy.

在您的示例中,c = a[:, 2:3] 是基本索引,而 c = a[:, [2]] 是高级索引。这就是为什么只在第一种情况下创建视图的原因。因此,.storage().data_ptr() 给出相同的结果。

您可以在 Numpy indexing docs 中阅读有关基本和高级索引的信息。

Advanced indexing is triggered when the selection object, obj, is a non-tuple sequence object, an ndarray (of data type integer or bool), or a tuple with at least one sequence object or ndarray (of data type integer or bool).