Pytorch Tensor::data_ptr<long long>() 无法在 Linux 上运行

Pytorch Tensor::data_ptr<long long>() not working on Linux

我无法link我的程序到Linux下的pytorch,得到以下错误:

/tmp/ccbgkLx2.o: In function `long long* at::Tensor::data<long long>() const':
test.cpp:(.text._ZNK2at6Tensor4dataIxEEPT_v[_ZNK2at6Tensor4dataIxEEPT_v]+0x14): undefined reference to `long long* at::Tensor::data_ptr<long long>() const'

我正在构建一个非常简单的最小示例:

#include "torch/script.h"
#include <iostream>

int main() {
    auto options = torch::TensorOptions().dtype(torch::kInt64);
    torch::NoGradGuard no_grad;
    auto T = torch::zeros(20, options).view({ 10, 2 });
    long long *data = (long long *)T.data<long long>();
    data[0] = 1;
    return 0;
}

用于构建它的命令:

g++ -w -std=c++17 -o test-torch test.cpp -D_GLIBCXX_USE_CXX11_ABI=1 -Wl,--whole-archive -ldl -lpthread -Wl,--no-whole-archive -I../libtorch/include -L../libtorch/lib -ltorch -ltorch_cpu -lc10 -Wl,-rpath,../libtorch/lib

Pytorch 已从 link https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcpu.zip 下载并解压缩(所以我在带有 test.cpp 的文件夹旁边有 libtorch 文件夹)。

有什么办法解决这个问题吗?同样的程序在 Visual C++ 下工作得很好。

P.S。我知道 pytorch 是为 cmake 设计的,但我对 cmake 的经验为零,也不想为我的应用程序编写 cmake-based 构建系统。此外,他们给出的示例似乎只有在系统中“安装”了 pytorch 时才有效。所以我不能只下载带有库的 .zip 文件吗?如果我在 AVX512 系统上“安装”它(例如,从源代码或以任何其他方式),我 link 到它并分发到 end-users 的二进制文件是否可以在非 AVX512 上工作?文档对新手来说完全看不懂

更新:我尝试按照教程 https://pytorch.org/cppdocs/installing.html 通过 CMake 执行此操作,但得到了完全相同的错误。具体来说,我将目录重命名为 example-app,并将源文件重命名为 example-app.cpp。然后我在这个目录下创建了CMakeLists.txt,内容如下:

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(example-app)

find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)

然后

mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=../../libtorch ..
cmake --build . --config Release

这是输出:

CMakeFiles/example-app.dir/example-app.cpp.o: In function `long long* at::Tensor::data<long long>() const':
example-app.cpp:(.text._ZNK2at6Tensor4dataIxEEPT_v[_ZNK2at6Tensor4dataIxEEPT_v]+0x14): undefined reference to `long long* at::Tensor::data_ptr<long long>() const'

让我觉得,也许我忘了包含一些 header 或定义了一些变量? 哦,这都是在Mint 19.2(相当于Ubuntu18.04)上,g++版本是7.5.0,glibc是2.27。用 g++-8 编译得到相同的结果。

这不是与 cmake 相关的错误,而是库的实现方式。我不知道为什么,但 T* at::Tensor::data<T> constT = long long 的特化似乎是 forgotten/omitted.

如果你想得到带符号的 64 位指针,你仍然可以用 int64_t:

auto data = T.data<int64_t>();

一般来说,最好使用这些大小明确的类型,以避免出现兼容性问题。