__cinit__() 在扩展 Cython class 时正好采用 2 个位置参数

__cinit__() takes exactly 2 positional arguments when extending a Cython class

我想扩展 scikit-learn 的 ClassificationCriterion class,它在内部模块 sklearn.tree._criterion 中定义为 Cython class。我想在 Python 中这样做,因为通常我无法访问 sklearn 的 pyx/pxd 文件(所以我不能 cimport 它们)。但是,当我尝试扩展 ClassificationCriterion 时,出现错误 TypeError: __cinit__() takes exactly 2 positional arguments (0 given)。下面的 MWE 重现了错误,并表明错误发生在 __new__ 之后但在 __init__.

之前

有没有办法像这样扩展 Cython class?

from sklearn.tree import DecisionTreeClassifier
from sklearn.tree._criterion import ClassificationCriterion

class MaxChildPrecision(ClassificationCriterion):
    def __new__(self, *args, **kwargs):
        print('new')
        super().__new__(MaxChildPrecision, *args, **kwargs)

    def __init__(self, *args, **kwargs):
        print('init')
        super(MaxChildPrecision).__init__(*args, **kwargs)

clf = DecisionTreeClassifier(criterion=MaxChildPrecision())

有两个问题。首先,ClassificationCriterion requires two specific arguments to its constructor that you aren't passing it。您将必须弄清楚这些参数代表什么并将它们传递给基数 class.

其次,有一个 Cython 问题。如果我们查看 the description of how to use __cinit__ 那么我们会看到:

Any arguments passed to the constructor will be passed to both the __cinit__() method and the __init__() method. If you anticipate subclassing your extension type in Python, you may find it useful to give the __cinit__() method * and ** arguments so that it can accept and ignore extra arguments. Otherwise, any Python subclass which has an init() with a different signature will have to override __new__() as well as __init__()

不幸的是,sklearn 的作者没有提供 *** 参数,因此您必须覆盖 __new__。这样的事情应该有效:

class MaxChildPrecision(ClassificationCriterion):
    def __init__(self,*args, **kwargs):
        pass

    def __new__(cls,*args,**kwargs):
        # I have NO IDEA if these arguments make sense!
        return super().__new__(cls,n_outputs=5,
                           n_classes=np.ones((2,),dtype=np.int))

我在 __new__ 中将必要的参数传递给 ClassificationCriterion 并在我认为合适的情况下在 __init__ 中处理其余部分。我不需要调用基数 class __init__(因为基数 class 没有定义 __init__)。