如何在 PySpark 中获取估计器的所有参数
How to get all parameters of estimator in PySpark
我有一个 RandomForestRegressor
,GBTRegressor
,我想获取它们的所有参数。我发现它的唯一方法可以通过几种 get 方法来完成,例如:
from pyspark.ml.regression import RandomForestRegressor, GBTRegressor
est = RandomForestRegressor()
est.getMaxDepth()
est.getSeed()
但是 RandomForestRegressor
和 GBTRegressor
具有不同的参数,因此硬核所有这些方法并不是一个好主意。
解决方法可能是这样的:
get_methods = [method for method in dir(est) if method.startswith('get')]
params_est = {}
for method in get_methods:
try:
key = method[3:]
params_est[key] = getattr(est, method)()
except TypeError:
pass
那么输出将是这样的:
params_est
{'CacheNodeIds': False,
'CheckpointInterval': 10,
'FeatureSubsetStrategy': 'auto',
'FeaturesCol': 'features',
'Impurity': 'variance',
'LabelCol': 'label',
'MaxBins': 32,
'MaxDepth': 5,
'MaxMemoryInMB': 256,
'MinInfoGain': 0.0,
'MinInstancesPerNode': 1,
'NumTrees': 20,
'PredictionCol': 'prediction',
'Seed': None,
'SubsamplingRate': 1.0}
但我认为应该有更好的方法来做到这一点。
extractParamMap
可用于获取每个估算器的所有参数,例如:
>>> est = RandomForestRegressor()
>>> {param[0].name: param[1] for param in est.extractParamMap().items()}
{'numTrees': 20, 'cacheNodeIds': False, 'impurity': 'variance', 'predictionCol': 'prediction', 'labelCol': 'label', 'featuresCol': 'features', 'minInstancesPerNode': 1, 'seed': -5851613654371098793, 'maxDepth': 5, 'featureSubsetStrategy': 'auto', 'minInfoGain': 0.0, 'checkpointInterval': 10, 'subsamplingRate': 1.0, 'maxMemoryInMB': 256, 'maxBins': 32}
>>> est = GBTRegressor()
>>> {param[0].name: param[1] for param in est.extractParamMap().items()}
{'cacheNodeIds': False, 'impurity': 'variance', 'predictionCol': 'prediction', 'labelCol': 'label', 'featuresCol': 'features', 'stepSize': 0.1, 'minInstancesPerNode': 1, 'seed': -6363326153609583521, 'maxDepth': 5, 'maxIter': 20, 'minInfoGain': 0.0, 'checkpointInterval': 10, 'subsamplingRate': 1.0, 'maxMemoryInMB': 256, 'lossType': 'squared', 'maxBins': 32}
如中所述,您可以使用以下结构获取任何模型的原始 JVM 对象中可用的任何模型参数
<yourModel>.stages[<yourModelStage>]._java_obj.<getYourParameter>()
所有获取参数都在这里可用
https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/RandomForestClassificationModel.html
例如,如果您想在交叉验证后获取 RandomForest 的 MaxDepth(getMaxDepth 在 PySpark 中不可用),您可以使用
cvModel.bestModel.stages[-1]._java_obj.getMaxDepth()
我有一个 RandomForestRegressor
,GBTRegressor
,我想获取它们的所有参数。我发现它的唯一方法可以通过几种 get 方法来完成,例如:
from pyspark.ml.regression import RandomForestRegressor, GBTRegressor
est = RandomForestRegressor()
est.getMaxDepth()
est.getSeed()
但是 RandomForestRegressor
和 GBTRegressor
具有不同的参数,因此硬核所有这些方法并不是一个好主意。
解决方法可能是这样的:
get_methods = [method for method in dir(est) if method.startswith('get')]
params_est = {}
for method in get_methods:
try:
key = method[3:]
params_est[key] = getattr(est, method)()
except TypeError:
pass
那么输出将是这样的:
params_est
{'CacheNodeIds': False,
'CheckpointInterval': 10,
'FeatureSubsetStrategy': 'auto',
'FeaturesCol': 'features',
'Impurity': 'variance',
'LabelCol': 'label',
'MaxBins': 32,
'MaxDepth': 5,
'MaxMemoryInMB': 256,
'MinInfoGain': 0.0,
'MinInstancesPerNode': 1,
'NumTrees': 20,
'PredictionCol': 'prediction',
'Seed': None,
'SubsamplingRate': 1.0}
但我认为应该有更好的方法来做到这一点。
extractParamMap
可用于获取每个估算器的所有参数,例如:
>>> est = RandomForestRegressor()
>>> {param[0].name: param[1] for param in est.extractParamMap().items()}
{'numTrees': 20, 'cacheNodeIds': False, 'impurity': 'variance', 'predictionCol': 'prediction', 'labelCol': 'label', 'featuresCol': 'features', 'minInstancesPerNode': 1, 'seed': -5851613654371098793, 'maxDepth': 5, 'featureSubsetStrategy': 'auto', 'minInfoGain': 0.0, 'checkpointInterval': 10, 'subsamplingRate': 1.0, 'maxMemoryInMB': 256, 'maxBins': 32}
>>> est = GBTRegressor()
>>> {param[0].name: param[1] for param in est.extractParamMap().items()}
{'cacheNodeIds': False, 'impurity': 'variance', 'predictionCol': 'prediction', 'labelCol': 'label', 'featuresCol': 'features', 'stepSize': 0.1, 'minInstancesPerNode': 1, 'seed': -6363326153609583521, 'maxDepth': 5, 'maxIter': 20, 'minInfoGain': 0.0, 'checkpointInterval': 10, 'subsamplingRate': 1.0, 'maxMemoryInMB': 256, 'lossType': 'squared', 'maxBins': 32}
如
<yourModel>.stages[<yourModelStage>]._java_obj.<getYourParameter>()
所有获取参数都在这里可用 https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/RandomForestClassificationModel.html
例如,如果您想在交叉验证后获取 RandomForest 的 MaxDepth(getMaxDepth 在 PySpark 中不可用),您可以使用
cvModel.bestModel.stages[-1]._java_obj.getMaxDepth()