为什么在与从文件读取的相同张量进行比较时,C++ PyTorch API 中的“is_same”会失败?

Why does `is_same` in the C++ PyTorch API fail when comparing with the same tensor that is read from a file?

为什么 torch::Tensor::is_same 以下断言失败?使用 C++ PyTorch API 将张量写入文件,然后再次读入另一个张量,is_same 比较两个张量:

torch::Tensor x_sequence = torch::linspace(0, M_PI, 1000);    
torch::save(x_sequence, "x_sequence.dat");
torch::Tensor x_read;
torch::load(x_read, "x_sequence.dat");
assert(x_read.is_same(x_sequence));  

这导致:

int main(int, char**): Assertion `x_read.is_same(x_sequence)' failed.

使用

torch::Tensor::is_same(const torch::Tensor& other) 定义为 here。重要的是要注意 Tensor 实际上是底层 TensorImpl class 上的指针(实际上保存数据)。

因此,当您调用 is_same 时,检查的实际上是您的指针是否相同,即您的 2 个张量是否指向相同的底层内存。这是一个非常简单的例子,很好理解:

auto x = torch::randn({4,4});
auto copy = x;
auto clone = x.clone();
std::cout << x.is_same(copy) << " " << x.is_same(clone) << std::endl;
>>> 0 1

在这里,对clone的调用强制pytorch将数据复制到另一个内存位置。因此,指针不同并且 is_same returns false.

如果你想真正比较这些值,你别无选择,只能计算两个张量之间的差异,并计算这个差异有多接近 0。