• 【阿旭机器学习实战】【25】Decision Tree Model: Leaf Classification in Practice


    The 【阿旭机器学习实战】 series introduces various machine learning algorithms and models through hands-on case studies. Likes and follows are welcome; let's learn and exchange ideas together.

    In this article we build a decision tree model on a leaf classification dataset, use it for prediction, and then optimize the model.

    Leaf Classification in Practice with a Decision Tree

    1. Import the Data

    import pandas as pd
    import matplotlib.pyplot as plt
    
    from sklearn.preprocessing import LabelEncoder
    from sklearn.model_selection import train_test_split
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.model_selection import GridSearchCV
    
    data = pd.read_csv('train.csv')
    
    data.head()
    
       id                species   margin1   margin2   margin3  ...  texture62  texture63  texture64
    0   1            Acer_Opalus  0.007812  0.023438  0.023438  ...   0.004883   0.000000   0.025391
    1   2  Pterocarya_Stenoptera  0.005859  0.000000  0.031250  ...   0.000977   0.039062   0.022461
    2   3   Quercus_Hartwissiana  0.005859  0.009766  0.019531  ...   0.000000   0.020508   0.002930
    3   5        Tilia_Tomentosa  0.000000  0.003906  0.023438  ...   0.017578   0.000000   0.047852
    4   6     Quercus_Variabilis  0.005859  0.003906  0.048828  ...   0.000000   0.000000   0.031250

    5 rows × 194 columns

    Data description:
    the species column is the class label; each sample also has 64 margin (edge) features, 64 shape features, and 64 texture features.

    There are 99 leaf species in total.

    data.shape
    
    (990, 194)
    
    # Check the number of leaf species
    len(data.species.unique())
    
    99
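
    As a quick sanity check (not part of the original notebook), the three feature groups described above can be counted directly from the column names shown in data.head():

    # Count how many columns belong to each feature group
    # (assumes the margin*/shape*/texture* naming visible in data.head())
    for prefix in ['margin', 'shape', 'texture']:
        print(prefix, sum(col.startswith(prefix) for col in data.columns))
    # Expected: margin 64, shape 64, texture 64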
    

    2. Feature Engineering

    # Encode the string species labels as integers
    lb = LabelEncoder().fit(data.species)
    labels = lb.transform(data.species)
    # Drop the 'species' and 'id' columns, which are of no use for training
    data = data.drop(['species', 'id'], axis=1)
    data.head()
    
        margin1   margin2   margin3   margin4  ...  texture61  texture62  texture63  texture64
    0  0.007812  0.023438  0.023438  0.003906  ...        0.0   0.004883   0.000000   0.025391
    1  0.005859  0.000000  0.031250  0.015625  ...        0.0   0.000977   0.039062   0.022461
    2  0.005859  0.009766  0.019531  0.007812  ...        0.0   0.000000   0.020508   0.002930
    3  0.000000  0.003906  0.023438  0.005859  ...        0.0   0.017578   0.000000   0.047852
    4  0.005859  0.003906  0.048828  0.009766  ...        0.0   0.000000   0.000000   0.031250

    5 rows × 192 columns

    labels[:5]
    
    array([ 3, 49, 65, 94, 84], dtype=int64)
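
    As a small illustrative check (not in the original code), the encoded integers can be mapped back to species names with the fitted LabelEncoder:

    # Decode the first five labels back to species names
    print(lb.inverse_transform(labels[:5]))
    # Should print the first five species from data.head():
    # ['Acer_Opalus' 'Pterocarya_Stenoptera' 'Quercus_Hartwissiana'
    #  'Tilia_Tomentosa' 'Quercus_Variabilis']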
    
    # Split the dataset into training and test sets
    x_train,x_test,y_train,y_test = train_test_split(data, labels, test_size=0.2, stratify=labels)
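
    Because every class has only a handful of samples, stratify=labels matters here: it keeps the class proportions identical in both splits. A minimal check (an illustrative addition, assuming the dataset is balanced with 10 samples per species):

    import numpy as np

    # With stratification and test_size=0.2, each species should contribute
    # 8 samples to the training set and 2 to the test set
    print(np.bincount(y_train).min(), np.bincount(y_train).max())  # expected: 8 8
    print(np.bincount(y_test).min(), np.bincount(y_test).max())    # expected: 2 2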
    

    3. Build the Decision Tree Model

    tree = DecisionTreeClassifier()
    tree.fit(x_train, y_train)
    
    DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
                max_features=None, max_leaf_nodes=None,
                min_impurity_decrease=0.0, min_impurity_split=None,
                min_samples_leaf=1, min_samples_split=2,
                min_weight_fraction_leaf=0.0, presort=False, random_state=None,
                splitter='best')
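
    Optionally, the top of the fitted tree can be visualized with sklearn.tree.plot_tree (an illustrative sketch, not part of the original article; it requires scikit-learn >= 0.21 and uses the matplotlib import from above):

    from sklearn.tree import plot_tree

    # The full tree is far too large to draw, so only the first two levels are shown
    plt.figure(figsize=(12, 6))
    plot_tree(tree, max_depth=2, filled=True, fontsize=8)
    plt.show()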
    
    tree.score(x_test, y_test)
    
    0.6767676767676768
    
    tree.score(x_train, y_train)
    
    1.0
    

    The results show that the model reaches 100% accuracy on the training set but only about 68% on the test set, a clear sign of overfitting, so the model needs further optimization.
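
    Before jumping into a full grid search, a quick sketch like the one below (an illustrative addition; results vary between runs because neither the split nor the tree is seeded) shows how limiting the tree's capacity narrows the gap between training and test accuracy:

    # Compare train/test accuracy for a few maximum depths
    for depth in [5, 10, 15, None]:
        t = DecisionTreeClassifier(max_depth=depth)
        t.fit(x_train, y_train)
        print(depth, round(t.score(x_train, y_train), 3), round(t.score(x_test, y_test), 3))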

    4. Model Optimization

    # max_depth: maximum depth of the tree
    # min_samples_split: minimum number of samples required to split an internal node
    # min_samples_leaf: minimum number of samples required at a leaf node
    param_grid = {'max_depth': [10, 15, 20, 25, 30],
                  'min_samples_split': [2, 3, 4, 5, 6, 7, 8],
                  'min_samples_leaf': [1, 2, 3, 4, 5, 6, 7]}
    # Grid search with 3-fold cross-validation
    model = GridSearchCV(tree, param_grid, cv=3)
    model.fit(x_train, y_train)
    print(model.best_estimator_)
    
    DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=30,
                max_features=None, max_leaf_nodes=None,
                min_impurity_decrease=0.0, min_impurity_split=None,
                min_samples_leaf=4, min_samples_split=5,
                min_weight_fraction_leaf=0.0, presort=False, random_state=None,
                splitter='best')
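
    The winning hyper-parameter combination and its mean cross-validation accuracy can also be read directly from the fitted GridSearchCV object:

    # Best parameter combination (matches the estimator printed above)
    print(model.best_params_)
    # Mean cross-validation accuracy of that combination
    print(model.best_score_)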
    
    model.score(x_train, y_train)
    
    0.9444444444444444
    
    model.score(x_test, y_test)
    
    0.6868686868686869
    
    After tuning, training accuracy drops to about 94% while test accuracy rises slightly to about 69%, so the overfitting is reduced, though there is still room for improvement.
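
    For a more detailed picture than a single accuracy number, a per-class report can be produced from the tuned model (an illustrative sketch using the encoder fitted earlier):

    from sklearn.metrics import classification_report

    # Precision, recall and F1 for each of the 99 species on the test set
    y_pred = model.predict(x_test)
    print(classification_report(y_test, y_pred, target_names=lb.classes_))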

    If this content helps you, a like and a follow would be much appreciated!

    Feel free to follow my WeChat official account, 阿旭算法与机器学习, so we can keep learning and exchanging ideas together.
    More useful content is being updated all the time…

  • Original article: https://blog.csdn.net/qq_42589613/article/details/127770111