How to plot the DecisionTree out of a SkLearn Pipeline?
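I'm working with a DecisionTree inside an sklearn Pipeline. The model itself works fine, but I can't plot the decision tree: I'm not sure which object to pass to the plotting function.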
Here is the code that builds the decision tree model:
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.preprocessing import (
    OneHotEncoder, PowerTransformer, StandardScaler
)

# Build categorical preprocessor
categorical_cols = X.select_dtypes(include="object").columns.to_list()
categorical_pipe = make_pipeline(
    # `sparse` was renamed to `sparse_output` in scikit-learn 1.2;
    # on older versions use OneHotEncoder(sparse=False, handle_unknown="ignore")
    OneHotEncoder(sparse_output=False, handle_unknown="ignore")
)

# Build numeric processors
to_log = ["SA13_peopleHH"]  # power-transformed (Yeo-Johnson by default)
to_scale = ["SA11_age"]     # standardized to zero mean, unit variance
numeric_pipe_1 = make_pipeline(PowerTransformer())
numeric_pipe_2 = make_pipeline(StandardScaler())

# Full processor: each transformer is applied to its own column list
full = ColumnTransformer(
    transformers=[
        ("categorical", categorical_pipe, categorical_cols),
        ("power_transform", numeric_pipe_1, to_log),
        ("standardization", numeric_pipe_2, to_scale),
    ]
)

# Final pipeline combined with DecisionTree
pipeline = Pipeline(
    steps=[
        ("preprocess", full),
        ("base", DecisionTreeClassifier()),
    ]
)

# Fit
_ = pipeline.fit(X_train, y_train)
And this is how I call the plotting function:
tree.plot_tree(pipeline)
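A Pipeline is only a container for its steps: the fitted tree structure (the tree_ attribute) lives on the final estimator, not on the Pipeline object itself. The minimal sketch below illustrates this with hypothetical random data and a reduced pipeline (StandardScaler plus the tree) standing in for the question's X_train / y_train, and shows three equivalent ways of retrieving the fitted step named "base".

```python
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

# Hypothetical toy data standing in for X_train / y_train from the question
rng = np.random.default_rng(0)
X_toy = rng.normal(size=(40, 2))
y_toy = (X_toy[:, 0] > 0).astype(int)

pipeline = Pipeline(
    steps=[
        ("scale", StandardScaler()),
        ("base", DecisionTreeClassifier(random_state=0)),
    ]
)
pipeline.fit(X_toy, y_toy)

# Three equivalent ways to get the fitted tree out of the pipeline
by_name = pipeline["base"]                  # index by step name
by_named_steps = pipeline.named_steps["base"]
by_position = pipeline[-1]                  # last step

print(type(by_name).__name__)                    # DecisionTreeClassifier
print(by_name is by_named_steps is by_position)  # True
# Only the inner estimator carries the fitted tree structure
print(hasattr(by_name, "tree_"), hasattr(pipeline, "tree_"))  # True False
```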
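Answer:
I think tree.plot_tree(pipeline['base']) should work. plot_tree expects a fitted decision tree estimator, not a Pipeline, and pipeline['base'] returns the fitted DecisionTreeClassifier that was registered in the pipeline under the step name "base".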