Cython 中的继承和 std::shared_ptr

Inheritance and std::shared_ptr in Cython

假设我在 file.h 中有以下简单的 C++ 继承示例:

class Base {};
class Derived : public Base {};

然后,编译以下代码;也就是说,我可以将 std::shared_ptr<Derived> 分配给 std::shared_ptr<Base>:

Derived* foo = new Derived();
std::shared_ptr<Derived> shared_foo = std::make_shared<Derived>(*foo);
std::shared_ptr<Base> bar = shared_foo;

假设我已经将类型添加到 decl.pxd:

cdef extern from "file.h":
    cdef cppclass Base:
        pass
    cdef cppclass Derived(Base):
        pass

然后,我要做的是在 file.pyx:

中模仿 Cython 中的上述 C++ 赋值
cimport decl
from libcpp.memory cimport make_shared, shared_ptr

def do_stuff():
    cdef decl.Derived* foo = new decl.Derived()
    cdef shared_ptr[decl.Derived] shared_foo = make_shared[decl.Derived](foo)
    cdef shared_ptr[decl.Base] bar = shared_foo

与 C++ 的情况不同,这现在失败并出现以下错误(使用 Cython 3.0a6):

cdef shared_ptr[decl.Base] bar = shared_foo
                                ^
---------------------------------------------------------------

 Cannot assign type 'shared_ptr[Derived]' to 'shared_ptr[Base]'

我应该期待这种行为吗?有什么方法可以模仿 C++ 示例对 Cython 的作用吗?

编辑:比照。对下面接受的答案的评论,相关功能已添加到 Cython 中,并且从版本 3.0a7 开始可用。

我没有尝试过 Cyton,但是 std::shared_ptr 有一个静态转换函数 std::static_pointer_cast。我认为这会起作用

std::shared_ptr<Base> bar = std::static_pointer_cast<Base>(shared_foo);

.

def do_stuff():
    cdef decl.Derived* foo = new decl.Derived()
    cdef shared_ptr[decl.Derived] shared_foo = make_shared[decl.Derived](foo)
    cdef shared_ptr[decl.Base] bar = static_pointer_cast[decl.Base] shared_foo

旁注

您创建 shared_foo 的方式可能不是您想要的。在这里,您首先创建一个动态分配的 Derived。然后您正在创建一个新的动态分配的共享派生,它是原始副本。

// this allocates one Derived
Derived* foo = new Derived(); 
// This allocates a new copy, it does not take ownership of foo
std::shared_ptr<Derived> shared_foo = std::make_shared<Derived>(*foo); 

您可能想要的是:

Derived* foo = new Derived();
std::shared_ptr<Derived> shared_foo(foo); // This now takes ownership of foo

或者只是:

// This makes a default constructed shared Derived
auto shared_foo = std::make_shared<Derived>(); 

它应该适用于 Cython>=3.0,因为 @fuglede 做了这个 PR 解决了下面描述的问题(对于 Cython<3.0 仍然存在)。


问题是,std::shared_ptrwrapper 未命中

template <class U> shared_ptr& operator= (const shared_ptr<U>& x) noexcept;

std::shared_ptr-class.

如果你像那样修补包装器:

cdef extern from "<memory>" namespace "std" nogil:
cdef cppclass shared_ptr[T]:
    ...
    shared_ptr[T]& operator=[Y](const shared_ptr[Y]& ptr)
    #shared_ptr[Y](shared_ptr[Y]&)  isn't accepted

您的代码将编译。

您可能会问,为什么需要 operator= 而不是构造函数 shared_ptr[Y],因为:

...
cdef shared_ptr[decl.Base] bar = shared_foo

看起来构造函数 (template <class U> shared_ptr (const shared_ptr<U>& x) noexcept;) 不是显式的。但这是 Cython 与 C++ 的怪癖之一。上面的代码会被翻译成

std::shared_ptr<Base> __pyx_v_bar;
...
__pyx_v_bar = __pyx_v_shared_foo;

而不是

std::shared_ptr<Base> __pyx_v_bar = __pyx_v_shared_foo;

因此 Cython 将检查 operator= 的存在(对我们来说很幸运,因为 Cython 似乎不支持带模板的构造函数,但支持运算符)。


如果你想在没有打补丁的系统上分发你的模块 memory.pxd 你有两个选择:

  1. 自己正确包装std::shared_ptr
  2. 写一个小实用函数,例如
%%cython
...
cdef extern from *:
    """
    template<typename T1, typename T2>
    void assign_shared_ptr(std::shared_ptr<T1>& lhs, const std::shared_ptr<T2>& rhs){
         lhs = rhs;
    }
    """
    void assign_shared_ptr[T1, T2](shared_ptr[T1]& lhs, shared_ptr[T2]& rhs)
    
...
cdef shared_ptr[Derived] shared_foo
# cdef shared_ptr[decl.Base] bar = shared_foo
# must be replaced through:
cdef shared_ptr[Base] bar 
assign_shared_ptr(bar, shared_foo)
...

这两种选择都有缺点,因此根据您的情况,您可能更喜欢其中一种。