Pyspark 的 Estimator 和 RandomForestClassifier 之间的关系是什么

What is the relation between Pyspark's Estimator and RandomForestClassifier

我目前正在与 pyspark.ml.classification.RandomForestClassifier 和 pyspark.ml.tuning.CrossValidator 合作。我显然可以使用 RandomForestClassifier 作为 CrossValidation 的“估计器”参数。但是 RandomForestClassifier 似乎并没有继承自 pyspark.ml.base.Estimator.

另一方面,查看 RandomForestClassifier (https://spark.apache.org/docs/latest/api/python/_modules/pyspark/ml/classification.html#RandomForestClassifier) 的源代码,我无法弄清楚 RandomForestClassifier 在哪里实现它的 fit 方法(我认为即使它继承也应该发生来自 Estimator,因此当您调用 RandomForestClassifier.fit() 时,您将获得 RandomForestClassifier 实现)。

那么如何在 CrossValidator 中使用 RandomForestClassfier 作为“估计器”呢?这两者之间的关系是什么 类 RandomForestClassifier fit 方法在哪里实现?

根据您链接的来源,RandomForestClassifier 继承自 _JavaProbabilisticClassifier:

class RandomForestClassifier(_JavaProbabilisticClassifier, # ...

继承自ProbabilisticClassifier:

class _JavaProbabilisticClassifier(ProbabilisticClassifier, _JavaClassifier, # ...

来自 Classifier

class ProbabilisticClassifier(Classifier, # ...

来自Predictor

class Classifier(Predictor, # ...

最后,Estimator

class Predictor(Estimator, # ...

对于你的第二个问题,fit 方法定义在 JavaEstimator here:

class JavaEstimator(JavaParams, Estimator, metaclass=ABCMeta):

    def _fit_java(self, dataset):     
        self._transfer_params_to_java()
        return self._java_obj.fit(dataset._jdf)

    def _fit(self, dataset):
        java_model = self._fit_java(dataset)
        model = self._create_model(java_model)
        return self._copyValues(model)

其中 self._java_obj.fit 调用 Java/Scala 分类器对象的 fit 方法。

JavaPredictor_JavaClassifier继承,_JavaClassifier_JavaProbabilisticClassifier继承

class _JavaClassifier(Classifier, JavaPredictor, # ...

请注意 Estimator.fit 调用 _fit。请参阅 here 了解其定义。