How to plot a tree without showing "samples" and "value" in a random forest?

I want to make my tree simpler. How can I plot it without showing the samples (e.g. 83) and the value (e.g. [34,53,29,26])? (I don't want the last two lines.)

Here is the part of my code that currently plots the tree.

X = df.iloc[:,0: -1] 
y = df.iloc[:,-1]    
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y)
clf = RandomForestClassifier()
clf.fit(X_train,y_train)
.
.
.
.
# Here, I guess I need to add some commands.
plot_tree(clf.estimators_[5], 
          feature_names=X.columns,
          class_names=names, 
          filled=True, 
          impurity=True, 
          rounded=True,
          max_depth = 3)

Suppose we have a dataset like this, and we assign a matplotlib axis using the ax= argument:

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn import tree
import matplotlib.pyplot as plt
import re
import matplotlib

fig, ax = plt.subplots(figsize=(8,5))

clf = RandomForestClassifier(random_state=0)
iris = load_iris()
clf = clf.fit(iris.data, iris.target)
tree.plot_tree(clf.estimators_[0], ax=ax,
               feature_names=iris.feature_names,
               class_names=iris.target_names)

I'm not sure this is the best way, but one approach is to go into ax.properties() and edit the text:

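def replace_text(obj):
    # plot_tree draws each node as a matplotlib Annotation; leave other artists alone.
    if isinstance(obj, matplotlib.text.Annotation):
        txt = obj.get_text()
        # Drop everything from "samples" up to the "class" line.
        txt = re.sub("\nsamples[^$]*class", "\nclass", txt)
        obj.set_text(txt)
    return obj

# set_text() edits the artists in place, so there is nothing to assign back.
for child in ax.get_children():
    replace_text(child)
fig.show()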
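
To see what the substitution does without drawing anything, it can be applied to a plain string. This is only a sketch: the label below is a made-up example written in the format plot_tree uses by default, not output copied from a real run.

```python
import re

# Hypothetical node label in plot_tree's default format (illustrative values only).
label = (
    "petal width (cm) <= 0.8\n"
    "gini = 0.667\n"
    "samples = 150\n"
    "value = [50, 50, 50]\n"
    "class = setosa"
)

# Same pattern as in replace_text: remove everything from "\nsamples"
# up to the word "class", then put "\nclass" back.
simplified = re.sub("\nsamples[^$]*class", "\nclass", label)
print(simplified)
```

Note that the pattern is greedy, so it runs up to the last occurrence of "class" in the label. If a class name itself contains the word "class", the name would be cut off.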

A small improvement on @StupidWolf's suggestion above. If there are many classes, value = [...] is split across multiple lines:

value = [100, 0, 0, 0, 0, 
6, 7, 0, 0, 0, 0, 
0, 13] 

So, instead of replacing the text with re.sub(...), I check which line starts the value part:

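def replace_text(obj):
  if isinstance(obj, matplotlib.text.Annotation):
    _lines = obj.get_text().splitlines()
    value_index = None
    class_index = None
    for i, _line in enumerate(_lines):
      # Match on the line prefix so a feature name containing "value" or "class" is not picked up.
      if _line.startswith("value = "):
        value_index = i
      if _line.startswith("class = "):
        class_index = i
    # Nodes without both lines (e.g. the "(...)" placeholders drawn when max_depth
    # truncates the tree) are left unchanged instead of failing an assert.
    if value_index is not None and class_index is not None:
      # Drop the value block, including its continuation lines; keep "class" onwards.
      obj.set_text("\n".join(_lines[:value_index] + _lines[class_index:]))
  return obj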
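
The function above removes the value block but keeps the samples line, while the question asks to hide both. A minimal sketch that drops both is shown here, assuming plot_tree's default label='all' format, where the lines start with "samples = ", "value = " and "class = ". The names drop_samples_and_value and simplify_labels are hypothetical helpers, not part of scikit-learn or matplotlib.

```python
def drop_samples_and_value(label):
    """Return the node label without its 'samples' line and 'value' block."""
    kept = []
    in_value = False  # True while we are inside a (possibly multi-line) value block
    for line in label.split("\n"):
        stripped = line.strip()
        if stripped.startswith("samples = "):
            in_value = False
            continue
        if stripped.startswith("value = "):
            in_value = True
            continue
        # Continuation lines of a wrapped value list: skip until "class = " appears.
        if in_value and not stripped.startswith("class = "):
            continue
        in_value = False
        kept.append(line)
    return "\n".join(kept)


def simplify_labels(artists):
    """Rewrite the text of each artist in place (anything with get_text/set_text)."""
    for artist in artists:
        artist.set_text(drop_samples_and_value(artist.get_text()))


# Possible usage with the setup from the answers above (not executed here):
#   annotations = tree.plot_tree(clf.estimators_[0], ax=ax)
#   simplify_labels(annotations)
```

plot_tree returns the list of annotation artists it created, so the helper can be applied to that list directly instead of walking over all children of the axis. Labels that contain neither line, such as the placeholders drawn for truncated subtrees, pass through unchanged. If plot_tree is called with a different label setting, the prefixes are not printed on every node and this sketch would need adjusting.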