• 决策树可视化方法与技巧汇总(1)(分类决策树)(含Python代码示例)


    决策树的可视化可以帮助我们非常直观的了解算法细节。但在具体使用过程中可能会遇到一些问题。以下是我整理的一些可视化方法:

    一、可视化工具Graphviz:

    Graphviz是一个开源的图(Graph)可视化软件,采用抽象的图和网络来表示结构化的信息。在数据科学领域,Graphviz的一个用途就是实现决策树可视化。

    1.使用export_graphviz 将树导出为 Graphviz 格式:

    from sklearn import tree
    from sklearn.datasets import load_wine
    
    wine = load_wine()
    clf = tree.DecisionTreeClassifier()
    clf.fit(wine.data, wine.target)
    
    with open("wine.dot", 'w') as f:
        tree.export_graphviz(clf, out_file=f)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    这里会生成一个纯文本文件iris.dot,我们可以直接打开查看,具体内容如下:

    digraph Tree {
    node [shape=box, fontname="helvetica"] ;
    edge [fontname="helvetica"] ;
    0 [label="X[12] <= 755.0\ngini = 0.658\nsamples = 178\nvalue = [59, 71, 48]"] ;
    1 [label="X[11] <= 2.115\ngini = 0.492\nsamples = 111\nvalue = [2, 67, 42]"] ;
    0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
    2 [label="X[10] <= 0.935\ngini = 0.227\nsamples = 46\nvalue = [0, 6, 40]"] ;
    1 -> 2 ;
    3 [label="X[6] <= 1.58\ngini = 0.049\nsamples = 40\nvalue = [0, 1, 39]"] ;
    2 -> 3 ;
    4 [label="gini = 0.0\nsamples = 39\nvalue = [0, 0, 39]"] ;
    3 -> 4 ;
    5 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1, 0]"] ;
    3 -> 5 ;
    6 [label="X[1] <= 2.395\ngini = 0.278\nsamples = 6\nvalue = [0, 5, 1]"] ;
    2 -> 6 ;
    7 [label="gini = 0.0\nsamples = 5\nvalue = [0, 5, 0]"] ;
    6 -> 7 ;
    8 [label="gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]"] ;
    6 -> 8 ;
    9 [label="X[6] <= 0.795\ngini = 0.117\nsamples = 65\nvalue = [2, 61, 2]"] ;
    1 -> 9 ;
    10 [label="gini = 0.0\nsamples = 2\nvalue = [0, 0, 2]"] ;
    9 -> 10 ;
    11 [label="X[0] <= 13.175\ngini = 0.061\nsamples = 63\nvalue = [2, 61, 0]"] ;
    9 -> 11 ;
    12 [label="gini = 0.0\nsamples = 58\nvalue = [0, 58, 0]"] ;
    11 -> 12 ;
    13 [label="X[12] <= 655.0\ngini = 0.48\nsamples = 5\nvalue = [2, 3, 0]"] ;
    11 -> 13 ;
    14 [label="gini = 0.0\nsamples = 3\nvalue = [0, 3, 0]"] ;
    13 -> 14 ;
    15 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0, 0]"] ;
    13 -> 15 ;
    16 [label="X[6] <= 2.165\ngini = 0.265\nsamples = 67\nvalue = [57, 4, 6]"] ;
    0 -> 16 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
    17 [label="X[9] <= 3.605\ngini = 0.375\nsamples = 8\nvalue = [0, 2, 6]"] ;
    16 -> 17 ;
    18 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2, 0]"] ;
    17 -> 18 ;
    19 [label="gini = 0.0\nsamples = 6\nvalue = [0, 0, 6]"] ;
    17 -> 19 ;
    20 [label="X[9] <= 3.435\ngini = 0.065\nsamples = 59\nvalue = [57, 2, 0]"] ;
    16 -> 20 ;
    21 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2, 0]"] ;
    20 -> 21 ;
    22 [label="gini = 0.0\nsamples = 57\nvalue = [57, 0, 0]"] ;
    20 -> 22 ;
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49

    这样的词典格式不方便我们理解,我们可视化:

    2.将.dot文件转换为可视化图形

    为了有更好的可视化效果,可以使用graphviz可执行包中的dot程序将其转化为可视化的PDF文档。

    具体方法为执行如下命令:

    dot -Tpdf wine.dot -o wine.pdf
    
    • 1

    转化完PDF后打开的图形如下:

    二、pydotplus生成PDF

    使用命令行非常的麻烦,可以采取的方式是安装pydotplus(pip install pydotplus)来生成PDF。

    另外在在使用tree.export_graphviz导出数据是还可以另外加一些参数,使得图片看起来更容易理解:

    from sklearn import tree
    from sklearn.datasets import load_wine
    import pydotplus
    
    wine = load_wine()
    clf = tree.DecisionTreeClassifier()
    clf.fit(wine.data, wine.target)
    
    dot_data = tree.export_graphviz(clf, out_file=None,
                                    feature_names=wine.feature_names,
                                    class_names=wine.target_names,
                                    filled=True, rounded=True,
                                    special_characters=True)
    
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_pdf('wine.pdf')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    生成的结果为:

    sklearn.tree.export_graphviz(decision_tree, out_file=None, *, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, leaves_parallel=False, impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, special_characters=False, precision=3)
    
    • 1
    传入参数:
    decision_tree:决策树对象
    out_file:输出文件的句柄或名称。
    max_depth:数的最大深度
    feature_names:特征名称列表
    class_names:类别名称列表,按升序排序
    label:显示纯度信息的选项{all, ‘root’, ‘none’}
    filled:绘制节点以指示分类的多数类、回归值的极值或多输出的节点的纯度。
    leaves_parallel:在树的底部绘制所有叶节点。
    impurity:是否显示纯度显示
    node_ids:是否显示每个节点的ID号
    proportion:将“值”和 “样本量”的显示分别更改为比例。
    rotate:设置未True是从左往右绘制,False是从上往下绘制。
    rounded:设置未True时,使用圆角进行绘制。
    special_characters:设置为时False,忽略特殊字符以实现PostScrip兼容性
    precision:每个节点数值的精度
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    要是觉得生成PDF查看比较麻烦,可采取生成图片:

    from sklearn import tree
    from sklearn.datasets import load_wine
    import pydotplus
    
    wine = load_wine()
    clf = tree.DecisionTreeClassifier()
    clf.fit(wine.data, wine.target)
    
    dot_data = tree.export_graphviz(clf, out_file=None,
                                    feature_names=wine.feature_names,
                                    class_names=wine.target_names,
                                    filled=True, rounded=True,
                                    special_characters=True)
    
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_png("dtree.png")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    三、scikit-learn的tree.plot_tree

    from sklearn import tree
    from sklearn.datasets import load_wine
    import matplotlib.pyplot as plt
    
    wine = load_wine()
    clf = tree.DecisionTreeClassifier()
    clf.fit(wine.data, wine.target)
    
    plt.figure(figsize=(12,8))
    tree.plot_tree(clf,max_depth=10,feature_names=wine.feature_names,class_names=wine.target_names,filled=True, rounded=True)
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    但是我们生成的图像分辨率较低!

    四、dtreeviz美化输出

    dtreeviz是一个美化输出的组件,在使用起来非常的简单:

    from sklearn.datasets import load_wine
    from dtreeviz.trees import dtreeviz
    wine = load_wine()
    clf = tree.DecisionTreeClassifier()
    clf.fit(wine.data, wine.target)
    
    viz = dtreeviz(clf,
                   x_data=wine.data,
                   y_data=wine.target,
                   target_name='class',
                   feature_names=wine.feature_names,
                   class_names=list(wine.target_names),
                   title="Decision Tree - wine data set")
    viz.save('dtreeviz.svg')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    我们生成的结果为:

    每个节点上,我们都可以看到用于分割观测值的特征的堆叠直方图,并按类别着色。通过这种方式,我们可以看到类是如何分割的。x轴的小三角形是拆分点。叶节点用饼图表示,饼图显示叶中的观察值属于哪个类。这样,我们就可以很容易地看到哪个类是最主要的,所以也可以看到模型的预测。我们也可以为测试集创建一个类似的可视化,只需要在调用函数时替换x_data和y_data参数。

    如果你不喜欢直方图并且希望简化绘图,可以指定fancy=False来接收以下简化绘图。


    dtreeviz的另一个方便的功能是提高模型的可解释性,即在绘图上突出显示特定观测值的路径。通过这种方式,我们可以清楚地看到哪些特征有助于类预测。使用下面的代码片段,我们突出显示测试集的第一个样本的路径。

    from sklearn import tree
    from sklearn.datasets import load_wine
    from dtreeviz.trees import dtreeviz
    wine = load_wine()
    clf = tree.DecisionTreeClassifier()
    clf.fit(wine.data, wine.target)
    
    viz = dtreeviz(clf,
                   x_data=wine.data,
                   y_data=wine.target,
                   target_name='class',
                   feature_names=wine.feature_names,
                   class_names=list(wine.target_names),
                   title="Decision Tree - Iris data set",
                   X=wine.data[0])
    viz.save('dtreeviz2.svg')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16


    这张图与前一张非常相似,然而,橙色突出清楚地显示了样本所遵循的路径。此外,我们可以在每个直方图上看到橙色三角形。它表示给定特征的观察值。我们还可以通过设置orientation=“LR”从上到下再从左到右更改绘图的方向。

    最后,我们可以打印这个观察预测所用的决定:

    print(explain_prediction_path(clf, wine.data[0], feature_names=wine.feature_names, explanation_type="plain_english"))
    
    • 1

    结果为:

    2.17 <= flavanoids 
    3.43 <= color_intensity 
    755.0 <= proline 
    
    • 1
    • 2
    • 3
  • 相关阅读:
    【C++】每周一题——2024.3.3
    深度学习(19):nerf论文公式理解
    @Tag和@Operation标签失效问题。SpringDoc 2.2.0(OpenApi 3)和Spring Boot 3.1.1集成
    PostGIS学习教程六:几何图形(geometry)
    从0到1做一个产品需要注意的基本点
    携程“919旅行囤货划算节”两年,已成行业超级IP
    数据结构学习笔记——插入排序
    阿里技术面总结
    实操客群分层|无监督训练与有监督评估,面试中这两大类风控模型最会被问到的问题
    题目:2715.执行可取消的延迟函数
  • 原文地址:https://blog.csdn.net/wzk4869/article/details/126248859