切片 pytorch 张量和 data_ptr() 的使用
Slicing pytorch tensors and use of data_ptr()
a = tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 8, 10],
[11, 12, 13, 14, 15]])
我有一个 torch
张量,我需要索引一个张量 c
这样 c = [[3], [8], [13]]
所以我做了 c = a[:,[2]]
这给了我预期的答案,但它在自动评分器上仍然失败。
autograder 使用如下检查功能 -
def check(orig, actual, expected):
expected = torch.tensor(expected)
same_elements = (actual == expected).all().item() == 1
same_storage = (orig.storage().data_ptr() == actual.storage().data_ptr())
return same_elements and same_storage
print('c correct:', check(a, c, [[3], [8], [13]]))
我试着调试了一下,发现same_storage
是假的,我不明白为什么orig.storage().data_ptr() == actual.storage().data_ptr()
应该是True
,以及它有什么不同。
更新
我能够通过 c = a[:, 2:3]
而不是 c = a[:, [2]]
得到正确答案,有什么区别?
PyTorch 允许张量成为现有张量的“视图”,这样它与其基础张量共享相同的底层数据,从而避免显式数据复制,从而能够执行快速且内存高效的操作。
中所述
When accessing the contents of a tensor via indexing, PyTorch follows Numpy behaviors that basic indexing returns views, while advanced indexing returns a copy.
在您的示例中,c = a[:, 2:3]
是基本索引,而 c = a[:, [2]]
是高级索引。这就是为什么只在第一种情况下创建视图的原因。因此,.storage().data_ptr()
给出相同的结果。
您可以在 Numpy indexing docs 中阅读有关基本和高级索引的信息。
Advanced indexing is triggered when the selection object, obj, is a non-tuple sequence object, an ndarray (of data type integer or bool), or a tuple with at least one sequence object or ndarray (of data type integer or bool).
a = tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 8, 10],
[11, 12, 13, 14, 15]])
我有一个 torch
张量,我需要索引一个张量 c
这样 c = [[3], [8], [13]]
所以我做了 c = a[:,[2]]
这给了我预期的答案,但它在自动评分器上仍然失败。
autograder 使用如下检查功能 -
def check(orig, actual, expected):
expected = torch.tensor(expected)
same_elements = (actual == expected).all().item() == 1
same_storage = (orig.storage().data_ptr() == actual.storage().data_ptr())
return same_elements and same_storage
print('c correct:', check(a, c, [[3], [8], [13]]))
我试着调试了一下,发现same_storage
是假的,我不明白为什么orig.storage().data_ptr() == actual.storage().data_ptr()
应该是True
,以及它有什么不同。
更新
我能够通过 c = a[:, 2:3]
而不是 c = a[:, [2]]
得到正确答案,有什么区别?
PyTorch 允许张量成为现有张量的“视图”,这样它与其基础张量共享相同的底层数据,从而避免显式数据复制,从而能够执行快速且内存高效的操作。
中所述When accessing the contents of a tensor via indexing, PyTorch follows Numpy behaviors that basic indexing returns views, while advanced indexing returns a copy.
在您的示例中,c = a[:, 2:3]
是基本索引,而 c = a[:, [2]]
是高级索引。这就是为什么只在第一种情况下创建视图的原因。因此,.storage().data_ptr()
给出相同的结果。
您可以在 Numpy indexing docs 中阅读有关基本和高级索引的信息。
Advanced indexing is triggered when the selection object, obj, is a non-tuple sequence object, an ndarray (of data type integer or bool), or a tuple with at least one sequence object or ndarray (of data type integer or bool).