How do I link in all of PyTorch with a pybind11 .so?

I have a pybind11 C++ project that uses the PyTorch C++ API:

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <math.h>
#include <torch/torch.h>

// ...

void f()
{
    // ...
    torch::Tensor dynamic_parameters = torch::full({1}, /*value=*/0.5, torch::dtype(torch::kFloat64).requires_grad(true));
    torch::optim::SGD optimizer({dynamic_parameters}, /*lr=*/0.01);
    // ...
}

PYBIND11_MODULE(reson8, m)
{
    m.def("my_function", &my_function, "");
}

I compile it with distutils into a .so that can be imported from Python:

from distutils.core import setup, Extension

def configuration(parent_package='', top_path=None):
      import numpy
      from numpy.distutils.misc_util import Configuration
      from numpy.distutils.misc_util import get_info

      #Necessary for the half-float d-type.
      info = get_info('npymath')

      config = Configuration('',
                             parent_package,
                             top_path)
      config.add_extension('reson8',
                           ['reson8.cpp'],
                           extra_info=info,
                           include_dirs=["/home/ian/anaconda3/lib/python3.7/site-packages/pybind11/include",
                                          "/home/ian/anaconda3/lib/python3.8/site-packages/pybind11/include",
                                          "/home/ian/dev/hedgey/Engine/lib/libtorch/include",
                                          "/home/ian/dev/hedgey/Engine/lib/libtorch/include/torch/csrc/api/include"])

      return config


if __name__ == "__main__":
      from numpy.distutils.core import setup
      setup(configuration=configuration)

It compiles without errors, but when I run "import reson8" in Python I get this error:

ImportError: undefined symbol: _ZTVN5torch5optim9OptimizerE

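I'm not sure whether PyTorch is failing to link into my .so (although the .so is 10 MB, which seems quite large if it doesn't include PyTorch; then again, maybe all pybind11 .so files are that large).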
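One way to narrow this down is to decode the symbol from the error. In the Itanium C++ ABI used by GCC and Clang on Linux, the _ZTV prefix marks a vtable and N...E wraps a nested name built from length-prefixed components. The helper below is hypothetical and decodes only that one vtable form; c++filt handles the general case.

```python
import re


def demangle_vtable(symbol):
    """Decode an Itanium C++ ABI vtable symbol of the form _ZTVN<len><name>...E.

    Hypothetical helper: only this one nested-name vtable form is handled;
    use c++filt for general demangling.
    """
    prefix = "_ZTVN"
    if not (symbol.startswith(prefix) and symbol.endswith("E")):
        raise ValueError(f"not a nested vtable symbol: {symbol!r}")
    body = symbol[len(prefix):-1]
    parts = []
    pos = 0
    while pos < len(body):
        # Each component is a decimal length followed by that many characters.
        match = re.match(r"\d+", body[pos:])
        if match is None:
            raise ValueError(f"unsupported encoding at {body[pos:]!r}")
        start = pos + len(match.group())
        end = start + int(match.group())
        if end > len(body):
            raise ValueError(f"truncated component in {symbol!r}")
        parts.append(body[start:end])
        pos = end
    return "vtable for " + "::".join(parts)


print(demangle_vtable("_ZTVN5torch5optim9OptimizerE"))
# prints: vtable for torch::optim::Optimizer
```

The vtable for torch::optim::Optimizer lives in libtorch alongside Optimizer's out-of-line virtual functions, so the error is consistent with the extension never having been linked against libtorch, rather than with a missing header.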

How can I fix this?
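The setup.py in the question only passes include_dirs; nothing tells the linker to link against libtorch. Because Linux shared objects are allowed to contain unresolved symbols, such a build succeeds and the missing symbols only surface when Python loads the module. A minimal sketch of the missing link step with a plain setuptools Extension follows. It assumes the standalone libtorch path from the question, the libtorch 1.5+ library layout (c10, torch, torch_cpu), and a C++17-capable compiler; the helper name torch_link_kwargs is hypothetical. The _GLIBCXX_USE_CXX11_ABI setting must also match how that libtorch was built, and, as the answer below found, mixing a separate libtorch with the torch installed in the same Python environment can cause its own conflicts.

```python
import os
import sys

# Assumption: the standalone libtorch directory from the question's include_dirs.
LIBTORCH = "/home/ian/dev/hedgey/Engine/lib/libtorch"


def torch_link_kwargs(libtorch_dir):
    """Return Extension keyword arguments for compiling and linking against libtorch."""
    lib_dir = os.path.join(libtorch_dir, "lib")
    return {
        "include_dirs": [
            os.path.join(libtorch_dir, "include"),
            os.path.join(libtorch_dir, "include", "torch", "csrc", "api", "include"),
        ],
        # The piece the original setup.py lacks: tell the linker where libtorch
        # lives and which of its libraries to link (libtorch 1.5+ layout assumed).
        "library_dirs": [lib_dir],
        "libraries": ["c10", "torch", "torch_cpu"],
        # Embed an rpath so the dynamic loader finds these libraries at import time.
        "runtime_library_dirs": [lib_dir],
        # Assumption: a C++17-capable compiler.
        "extra_compile_args": ["-std=c++17"],
    }


# Only build when a command such as "build_ext --inplace" is given.
if __name__ == "__main__" and len(sys.argv) > 1:
    import pybind11
    from setuptools import Extension, setup

    kwargs = torch_link_kwargs(LIBTORCH)
    kwargs["include_dirs"].insert(0, pybind11.get_include())
    setup(
        name="reson8",
        ext_modules=[Extension("reson8", ["reson8.cpp"], language="c++", **kwargs)],
    )
```

Running "python setup.py build_ext --inplace" would then produce a reson8 module whose torch symbols resolve against the listed libraries.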

I eventually found that I needed to use the libtorch from my Anaconda PyTorch installation instead of my own copy, and to use Torch's CppExtension. Here is my working setup.py:

from distutils.core import setup, Extension
from torch.utils.cpp_extension import BuildExtension, CppExtension

def configuration(parent_package='', top_path=None):
      import numpy
      from numpy.distutils.misc_util import Configuration
      from numpy.distutils.misc_util import get_info

      #Necessary for the half-float d-type.
      info = get_info('npymath')

      config = Configuration('',
                             parent_package,
                             top_path)

      config.ext_modules.append(CppExtension(
                name='reson8',
                sources=['reson8.cpp'],
                extra_info=info,
                extra_compile_args=['-g', '-D_GLIBCXX_USE_CXX11_ABI=0'],
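                # CppExtension forwards extra keywords to setuptools.Extension, which expects
                # extra_link_args (extra_ldflags is an option of torch.utils.cpp_extension.load).
                extra_link_args=['-ltorch_python'],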
                include_dirs=["/home/ian/anaconda3/lib/python3.7/site-packages/pybind11/include",
                                          "/home/ian/anaconda3/lib/python3.8/site-packages/pybind11/include",
                                          "/home/ian/anaconda3/lib"
                                          ]
                                ))

      return config


if __name__ == "__main__":
      from numpy.distutils.core import setup
      setup(configuration=configuration)
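The setup.py above imports BuildExtension but never registers it, and passes extra_info, which is a numpy.distutils option that setuptools' Extension does not recognize (distutils warns about unknown Extension options and ignores them). A sketch of the more common pattern from the torch.utils.cpp_extension documentation is below. It assumes reson8.cpp needs nothing from npymath; if it does, those settings would have to be added back explicitly. CppExtension already adds PyTorch's include directories (which bundle pybind11), its library directory, and the c10, torch, torch_cpu and torch_python libraries, while BuildExtension adds a C++ standard flag and the -D_GLIBCXX_USE_CXX11_ABI value that matches the installed torch.

```python
import sys

# Arguments forwarded to CppExtension, which supplies the torch include dirs,
# library dirs and libraries on top of these.
EXT_KWARGS = {
    "name": "reson8",
    "sources": ["reson8.cpp"],
    "extra_compile_args": ["-g"],
}


def main():
    # Imported lazily so this file can be inspected without torch installed.
    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CppExtension

    setup(
        name="reson8",
        ext_modules=[CppExtension(**EXT_KWARGS)],
        # BuildExtension picks the ABI and C++ standard flags to match torch.
        cmdclass={"build_ext": BuildExtension},
    )


# Only run setup() when a command such as "build_ext --inplace" is given.
if __name__ == "__main__" and len(sys.argv) > 1:
    main()
```

Because the flags come from the installed torch rather than being hard-coded, the extension is built against the same libtorch that "import torch" loads, which is the mismatch the answer ran into.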