如何可视化 sklearn GradientBoostingClassifier?
How to visualize an sklearn GradientBoostingClassifier?
我训练了 gradient boost classifier, and I would like to visualize it using the graphviz_exporter tool shown here。
当我尝试时,我得到:
AttributeError: 'GradientBoostingClassifier' object has no attribute 'tree_'
这是因为 graphviz_exporter 是为 decision trees 设计的,但我想仍然有办法将其可视化,因为梯度提升分类器必须有一个底层决策树。
有人知道怎么做吗?
属性估计器包含基础决策树。以下代码显示了经过训练的 GradientBoostingClassifier 的其中一棵树。请注意,尽管集成作为一个整体是一个分类器,但每个单独的树都会计算浮点值。
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import export_graphviz
import numpy as np
# Ficticuous data
np.random.seed(0)
X = np.random.normal(0,1,(1000, 3))
y = X[:,0]+X[:,1]*X[:,2] > 0
# Classifier
clf = GradientBoostingClassifier(max_depth=3, random_state=0)
clf.fit(X[:600], y[:600])
# Get the tree number 42
sub_tree_42 = clf.estimators_[42, 0]
# Visualization
# Install graphviz: https://www.graphviz.org/download/
from pydotplus import graph_from_dot_data
from IPython.display import Image
dot_data = export_graphviz(
sub_tree_42,
out_file=None, filled=True, rounded=True,
special_characters=True,
proportion=False, impurity=False, # enable them if you want
)
graph = graph_from_dot_data(dot_data)
Image(graph.create_png())
42号树:
我训练了 gradient boost classifier, and I would like to visualize it using the graphviz_exporter tool shown here。
当我尝试时,我得到:
AttributeError: 'GradientBoostingClassifier' object has no attribute 'tree_'
这是因为 graphviz_exporter 是为 decision trees 设计的,但我想仍然有办法将其可视化,因为梯度提升分类器必须有一个底层决策树。
有人知道怎么做吗?
属性估计器包含基础决策树。以下代码显示了经过训练的 GradientBoostingClassifier 的其中一棵树。请注意,尽管集成作为一个整体是一个分类器,但每个单独的树都会计算浮点值。
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import export_graphviz
import numpy as np
# Ficticuous data
np.random.seed(0)
X = np.random.normal(0,1,(1000, 3))
y = X[:,0]+X[:,1]*X[:,2] > 0
# Classifier
clf = GradientBoostingClassifier(max_depth=3, random_state=0)
clf.fit(X[:600], y[:600])
# Get the tree number 42
sub_tree_42 = clf.estimators_[42, 0]
# Visualization
# Install graphviz: https://www.graphviz.org/download/
from pydotplus import graph_from_dot_data
from IPython.display import Image
dot_data = export_graphviz(
sub_tree_42,
out_file=None, filled=True, rounded=True,
special_characters=True,
proportion=False, impurity=False, # enable them if you want
)
graph = graph_from_dot_data(dot_data)
Image(graph.create_png())
42号树: