sklearn.tree.export_graphviz alternatives

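Decision trees can be visualized with pydotplus from PyPI, but it has a problem on my machine: it reports that it was not built with libexpat, so each node shows only a number instead of a table with some information. I would like to use an alternative. I have already tried networkx, but it needs pygraphviz to read .dot files and build networkx graphs from them, and installing pygraphviz with pip also failed.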

So now I am looking for an alternative way to visualize decision trees that can be installed with pip or Anaconda.

What alternatives exist?

Edit #1

Output of conda list:

# packages in environment at /home/xiaolong/development/anaconda3/envs/coursera_ml_classification:
#
alabaster                 0.7.7                    py34_0    defaults
awscli                    1.6.2                     <pip>
babel                     2.3.3                    py34_0    defaults
backports                 1.0                      py34_0    defaults
backports-abc             0.4                       <pip>
backports.shutil-get-terminal-size 1.0.0                     <pip>
backports_abc             0.4                      py34_0    defaults
bcdoc                     0.12.2                    <pip>
boto                      2.33.0                    <pip>
botocore                  0.73.0                    <pip>
cairo                     1.12.18                       6    defaults
certifi                   2015.4.28                 <pip>
colorama                  0.2.5                     <pip>
cycler                    0.10.0                   py34_0    defaults
decorator                 4.0.9                    py34_0    defaults
docutils                  0.12                     py34_0    defaults
entrypoints               0.2                      py34_1    defaults
expat                     2.1.0                         0    defaults
fontconfig                2.11.1                        5    defaults
freetype                  2.5.5                         0    defaults
get_terminal_size         1.0.0                    py34_0    defaults
glib                      2.43.0                        2    asmeurer
graphviz                  2.38.0                        1    defaults
harfbuzz                  0.9.39                        0    defaults
imagesize                 0.7.0                    py34_0    defaults
ipykernel                 4.3.1                    py34_0    defaults
ipython                   4.2.0                    py34_0    defaults
ipython-genutils          0.1.0                     <pip>
ipython_genutils          0.1.0                    py34_0    defaults
ipywidgets                4.1.1                    py34_0    defaults
jedi                      0.9.0                    py34_0    defaults
jinja2                    2.8                      py34_0    defaults
jmespath                  0.5.0                     <pip>
jsonschema                2.5.1                    py34_0    defaults
jupyter                   1.0.0                    py34_2    defaults
jupyter-client            4.2.2                     <pip>
jupyter-console           4.1.1                     <pip>
jupyter-core              4.1.0                     <pip>
jupyter_client            4.2.2                    py34_0    defaults
jupyter_console           4.1.1                    py34_0    defaults
jupyter_core              4.1.0                    py34_0    defaults
libffi                    3.2.1                         0    defaults
libgcc                    5.2.0                         0    defaults
libgfortran               3.0.0                         1    defaults
libpng                    1.6.17                        0    defaults
libsodium                 1.0.3                         0    defaults
libxml2                   2.9.2                         0    defaults
llvmlite                  0.10.0                   py34_0    defaults
markupsafe                0.23                     py34_0    defaults
matplotlib                1.5.1               np111py34_0    defaults
mistune                   0.7.2                    py34_0    defaults
mkl                       11.3.1                        0    defaults
multipledispatch          0.4.8                     <pip>
nbconvert                 4.2.0                    py34_0    defaults
nbformat                  4.0.1                    py34_0    defaults
notebook                  4.2.0                    py34_0    defaults
numpy                     1.11.0                   py34_0    defaults
openssl                   1.0.2h                        0    defaults
pandas                    0.18.1              np111py34_0    defaults
pango                     1.39.0                        0    defaults
path.py                   8.2.1                    py34_0    defaults
pep8                      1.7.0                    py34_0    defaults
pexpect                   4.0.1                    py34_0    defaults
pickleshare               0.5                      py34_0    defaults
pip                       8.1.1                    py34_1    defaults
pixman                    0.32.6                        0    defaults
prettytable               0.7.2                     <pip>
psutil                    4.1.0                    py34_0    defaults
ptyprocess                0.5                      py34_0    defaults
pyasn1                    0.1.9                     <pip>
pydotplus                 2.0.2                    py34_0    file:///home/xiaolong/development/anaconda3/conda-bld/linux-64/pydotplus-2.0.2-py34_0.tar.bz2
pyflakes                  1.1.0                    py34_0    defaults
pygments                  2.1.3                    py34_0    defaults
pyparsing                 2.1.1                    py34_0    defaults
pyqt                      4.11.4                   py34_1    defaults
python                    3.4.4                         0    defaults
python-contrib-nbextensions alpha                     <pip>
python-dateutil           2.5.2                    py34_0    defaults
pytz                      2016.3                   py34_0    defaults
pyyaml                    3.11                      <pip>
pyzmq                     15.2.0                   py34_0    defaults
qt                        4.8.7                         1    defaults
qtconsole                 4.2.1                    py34_0    defaults
readline                  6.2                           2    defaults
requests                  2.9.1                     <pip>
rope                      0.9.4                    py34_1    defaults
rope-py3k                 0.9.4.post1               <pip>
rsa                       3.1.2                     <pip>
scikit-learn              0.17.1              np111py34_0    defaults
scipy                     0.17.0              np111py34_3    defaults
setuptools                20.7.0                   py34_0    defaults
sframe                    1.8.5                     <pip>
simplegeneric             0.8.1                    py34_0    defaults
sip                       4.16.9                   py34_0    defaults
six                       1.10.0                   py34_0    defaults
snowballstemmer           1.2.1                    py34_0    defaults
sphinx                    1.4.1                    py34_0    defaults
sphinx-rtd-theme          0.1.9                     <pip>
sphinx_rtd_theme          0.1.9                    py34_0    defaults
spyder                    2.3.8                    py34_1    defaults
sqlite                    3.9.2                         0    defaults
terminado                 0.5                      py34_1    defaults
tk                        8.5.18                        0    defaults
tornado                   4.3                      py34_0    defaults
traitlets                 4.2.1                    py34_0    defaults
wheel                     0.29.0                   py34_0    defaults
xz                        5.0.5                         1    defaults
zeromq                    4.1.3                         0    defaults
zlib                      1.2.8                         0    defaults

SciPy version: 0.17.0

digraph Tree {
node [shape=box, style="filled", color="black"] ;
0 [label="grade.B <= 0.5\ngini = 0.5\nsamples = 37224\nvalue = [18476, 18748]", fillcolor="#399de504"] ;
1 [label="grade.C <= 0.5\ngini = 0.4973\nsamples = 32094\nvalue = [17218, 14876]", fillcolor="#e5813923"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="gini = 0.4829\nsamples = 21728\nvalue = [12875, 8853]", fillcolor="#e5813950"] ;
1 -> 2 ;
3 [label="gini = 0.4869\nsamples = 10366\nvalue = [4343, 6023]", fillcolor="#399de547"] ;
1 -> 3 ;
4 [label="grade.A <= 14.8301\ngini = 0.3702\nsamples = 5130\nvalue = [1258, 3872]", fillcolor="#399de5ac"] ;
0 -> 4 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
5 [label="gini = 0.3555\nsamples = 4987\nvalue = [1153, 3834]", fillcolor="#399de5b2"] ;
4 -> 5 ;
6 [label="gini = 0.3902\nsamples = 143\nvalue = [105, 38]", fillcolor="#e58139a3"] ;
4 -> 6 ;
}
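Since pydotplus, pygraphviz and the Graphviz renderer are the parts that fail here, one fallback is to read the DOT text directly and print the tree as indented text. The following is only a minimal sketch using the Python standard library: the regular expressions assume the node and edge line layout shown in the DOT output above, and the names `parse_dot` and `render_text` are made up for this illustration.

```python
import re
from collections import defaultdict

# Copy of the DOT text above, kept as a raw string so "\n" stays a literal backslash-n.
DOT_SOURCE = r'''
digraph Tree {
node [shape=box, style="filled", color="black"] ;
0 [label="grade.B <= 0.5\ngini = 0.5\nsamples = 37224\nvalue = [18476, 18748]", fillcolor="#399de504"] ;
1 [label="grade.C <= 0.5\ngini = 0.4973\nsamples = 32094\nvalue = [17218, 14876]", fillcolor="#e5813923"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="gini = 0.4829\nsamples = 21728\nvalue = [12875, 8853]", fillcolor="#e5813950"] ;
1 -> 2 ;
3 [label="gini = 0.4869\nsamples = 10366\nvalue = [4343, 6023]", fillcolor="#399de547"] ;
1 -> 3 ;
4 [label="grade.A <= 14.8301\ngini = 0.3702\nsamples = 5130\nvalue = [1258, 3872]", fillcolor="#399de5ac"] ;
0 -> 4 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
5 [label="gini = 0.3555\nsamples = 4987\nvalue = [1153, 3834]", fillcolor="#399de5b2"] ;
4 -> 5 ;
6 [label="gini = 0.3902\nsamples = 143\nvalue = [105, 38]", fillcolor="#e58139a3"] ;
4 -> 6 ;
}
'''

# Assumed layout: node lines look like '<id> [label="..."', edge lines like '<id> -> <id>'.
NODE_RE = re.compile(r'^(\d+) \[label="(.*?)"', re.MULTILINE)
EDGE_RE = re.compile(r'^(\d+) -> (\d+)', re.MULTILINE)


def parse_dot(dot_text):
    """Return ({node_id: label}, {parent_id: [child_ids]}) for the DOT text."""
    labels = {int(node_id): label.replace('\\n', ', ')
              for node_id, label in NODE_RE.findall(dot_text)}
    children = defaultdict(list)
    for parent, child in EDGE_RE.findall(dot_text):
        children[int(parent)].append(int(child))
    return labels, children


def render_text(labels, children, node=0, depth=0):
    """Return the subtree rooted at `node` as a list of indented lines (depth-first)."""
    lines = ['    ' * depth + labels[node]]
    for child in children.get(node, []):
        lines.extend(render_text(labels, children, child, depth + 1))
    return lines


labels, children = parse_dot(DOT_SOURCE)
tree_lines = render_text(labels, children)
print('\n'.join(tree_lines))
```

This loses the colors and the True/False edge labels, but it needs nothing beyond the standard library, so it works even when none of the Graphviz bindings install.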

Edit #2

I wrote this in a Jupyter notebook, but there is a bug: if you try to display the SVG as follows, it is rendered without colors:

![Decision Tree]('dtree.svg')

I found a workaround here:

from IPython.display import HTML

svg = None
with open('dtree.svg') as svg_file:
    svg = svg_file.read()

HTML(svg)

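This is not the sexiest solution, but I use the Graphviz CLI (called dot), invoked through subprocess. I am on a Mac, so I installed it with Homebrew, but you can download binaries for other platforms from their downloads page. Here is an example using the Titanic dataset: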

import pandas as pd
import subprocess
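import seaborn.apionly as sns  # removed in newer seaborn releases; use "import seaborn as sns" there
from sklearn.preprocessing import Imputer  # scikit-learn >= 0.22: use sklearn.impute.SimpleImputer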
from sklearn.tree import DecisionTreeClassifier, export_graphviz

raw_data = sns.load_dataset('titanic')
predictors = ['pclass','sex','age','sibsp','parch','fare','embarked','alone','adult_male']
categorical = ['sex','embarked']
numeric = [c for c in predictors if c not in categorical]
target='survived'

encoded_data = pd.get_dummies(raw_data[predictors], columns=categorical)

imputer = Imputer()
X = imputer.fit_transform(encoded_data).astype('float32')
Y = raw_data[target].astype('float32')

model = DecisionTreeClassifier(min_samples_leaf=10, max_depth=3)
model.fit(X, Y)

export_graphviz(model,
                out_file='tree.dot',
                feature_names=encoded_data.columns,
                proportion=True,
                filled=True,
                impurity=False)

subprocess.call(['dot', '-Tpdf', 'tree.dot', '-o', 'tree.pdf'])
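Because `subprocess.call` only returns the exit status, a missing or failing `dot` binary is easy to overlook. The wrapper below is a small sketch of how the call could be made to fail loudly. The function name `render_dot` and its parameters are made up for this illustration; only the `-T` and `-o` flags come from the command above.

```python
import shutil
import subprocess


def render_dot(dot_path, out_path, fmt='pdf', executable='dot'):
    """Render a DOT file with the Graphviz CLI and return the output path.

    Raises FileNotFoundError if the executable is not on PATH and
    subprocess.CalledProcessError if it exits with a non-zero status.
    """
    if shutil.which(executable) is None:
        raise FileNotFoundError(
            '{!r} was not found on PATH; install Graphviz first'.format(executable))
    # check=True turns a non-zero exit status into an exception.
    subprocess.run([executable, '-T' + fmt, dot_path, '-o', out_path], check=True)
    return out_path
```

With Graphviz installed, `render_dot('tree.dot', 'tree.pdf')` does the same as the `subprocess.call` line above, and `fmt='svg'` or `fmt='png'` selects another output format.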

Since version 0.21, scikit-learn has a plot_tree function that draws the tree with matplotlib.

Code to use plot_tree:

from sklearn import tree
# the clf is Decision Tree object
tree.plot_tree(clf,
               feature_names=iris.feature_names,
               class_names=iris.target_names,
               filled=True)
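The snippet above assumes that `clf` and `iris` already exist. A self-contained version might look like the sketch below. The depth limit, the figure size, the `Agg` backend and the output file name `iris_tree.png` are my own choices for illustration, not part of the original answer.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend, so this also runs without a display
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Fit a small tree on the iris dataset (max_depth=2 keeps the figure readable).
iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(iris.data, iris.target)

# plot_tree draws onto the given axes and returns one annotation per drawn node.
fig, ax = plt.subplots(figsize=(8, 5))
annotations = plot_tree(clf,
                        feature_names=list(iris.feature_names),
                        class_names=list(iris.target_names),
                        filled=True,
                        ax=ax)

# Save to a file instead of relying on an interactive window.
fig.savefig('iris_tree.png', dpi=150, bbox_inches='tight')
```

No Graphviz installation is involved here, only scikit-learn (0.21 or later) and matplotlib, both of which install with pip or conda.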

The dtreeviz package is an alternative to sklearn's own plotting. Code to use dtreeviz:

from dtreeviz.trees import dtreeviz # remember to load the package
# the clf is Decision Tree object
viz = dtreeviz(clf, X, y,
                target_name="target",
                feature_names=iris.feature_names,
                class_names=list(iris.target_names))

viz

You can find a comparison of the different scikit-learn tree plotting techniques here