如何在随机森林中不显示 "samples" 和 "value" 的情况下绘制树?
How to plot tree without showing "samples" and "value" in random forest?
我想让我的树更简单,想知道如何在不显示样本(例如 83)和值(例如 [34,53,29,26])的情况下绘制树? (我不想要最后两行)
这是当前绘制树的部分代码。
X = df.iloc[:,0: -1]
y = df.iloc[:,-1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y)
clf = RandomForestClassifier()
clf.fit(X_train,y_train)
.
.
.
.
# Here, I guess I need to add some commands.
plot_tree(clf.estimators_[5],
feature_names=X.columns,
class_names=names,
filled=True,
impurity=True,
rounded=True,
max_depth = 3)
假设我们有这样一个数据集,我们使用 ax =
参数分配 matplotlib 轴:
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn import tree
import matplotlib.pyplot as plt
import re
import matplotlib
fig, ax = plt.subplots(figsize=(8,5))
clf = RandomForestClassifier(random_state=0)
iris = load_iris()
clf = clf.fit(iris.data, iris.target)
tree.plot_tree(clf.estimators_[0],ax=ax,
feature_names= iris.feature_names, class_names=iris.target_names)
不确定这是否是最好的方法,一种方法是进入 ax.properties()
并编辑文本:
def replace_text(obj):
if type(obj) == matplotlib.text.Annotation:
txt = obj.get_text()
txt = re.sub("\nsamples[^$]*class","\nclass",txt)
obj.set_text(txt)
return obj
ax.properties()['children'] = [replace_text(i) for i in ax.properties()['children']]
fig.show()
@StupidWolf 对上述命题的小改进。
如果类很多,则value = [...]
拆分成多行:
value = [100, 0, 0, 0, 0,
6, 7, 0, 0, 0, 0,
0, 13]
因此,我没有用 re.sub(...)
替换文本,而是检查哪一行开始 value
部分:
def replace_text(obj):
if type(obj) == matplotlib.text.Annotation:
txt = obj.get_text()
_lines = txt.splitlines()
_result = []
value_index = None
class_index = None
for i, _line in enumerate(_lines):
if "value" in _line:
value_index = i
if "class" in _line:
class_index = i
assert value_index and class_index
_result = _lines[:value_index] + _lines[class_index:]
obj.set_text("\n".join(_result))
return obj
我想让我的树更简单,想知道如何在不显示样本(例如 83)和值(例如 [34,53,29,26])的情况下绘制树? (我不想要最后两行)
这是当前绘制树的部分代码。
X = df.iloc[:,0: -1]
y = df.iloc[:,-1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y)
clf = RandomForestClassifier()
clf.fit(X_train,y_train)
.
.
.
.
# Here, I guess I need to add some commands.
plot_tree(clf.estimators_[5],
feature_names=X.columns,
class_names=names,
filled=True,
impurity=True,
rounded=True,
max_depth = 3)
假设我们有这样一个数据集,我们使用 ax =
参数分配 matplotlib 轴:
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn import tree
import matplotlib.pyplot as plt
import re
import matplotlib
fig, ax = plt.subplots(figsize=(8,5))
clf = RandomForestClassifier(random_state=0)
iris = load_iris()
clf = clf.fit(iris.data, iris.target)
tree.plot_tree(clf.estimators_[0],ax=ax,
feature_names= iris.feature_names, class_names=iris.target_names)
不确定这是否是最好的方法,一种方法是进入 ax.properties()
并编辑文本:
def replace_text(obj):
if type(obj) == matplotlib.text.Annotation:
txt = obj.get_text()
txt = re.sub("\nsamples[^$]*class","\nclass",txt)
obj.set_text(txt)
return obj
ax.properties()['children'] = [replace_text(i) for i in ax.properties()['children']]
fig.show()
@StupidWolf 对上述命题的小改进。
如果类很多,则value = [...]
拆分成多行:
value = [100, 0, 0, 0, 0,
6, 7, 0, 0, 0, 0,
0, 13]
因此,我没有用 re.sub(...)
替换文本,而是检查哪一行开始 value
部分:
def replace_text(obj):
if type(obj) == matplotlib.text.Annotation:
txt = obj.get_text()
_lines = txt.splitlines()
_result = []
value_index = None
class_index = None
for i, _line in enumerate(_lines):
if "value" in _line:
value_index = i
if "class" in _line:
class_index = i
assert value_index and class_index
_result = _lines[:value_index] + _lines[class_index:]
obj.set_text("\n".join(_result))
return obj