• 四、分类算法 - 朴素贝叶斯算法


    目录

    1、朴素贝叶斯算法

    1.1 案例

    1.2 联合概率、条件概率、相互独立

    1.3 贝叶斯公式

    1.4 朴素贝叶斯算法原理

    1.5 应用场景

    2、朴素贝叶斯算法对文本进行分类

    2.1 案例

    2.2 拉普拉斯平滑系数

    3、API

    4、案例:20类新闻分类

    4.1 步骤分析

    4.2 代码分析

    5、总结


    1. sklearn转换器和估算器
    2. KNN算法
    3. 模型选择和调优
    4. 朴素贝叶斯算法
    5. 决策树
    6. 随机森林

    1、朴素贝叶斯算法

    朴素?

    假设:特征与特征之间是相互独立的

    1.1 案例

    1.2 联合概率、条件概率、相互独立

    1.3 贝叶斯公式

    1.4 朴素贝叶斯算法原理

    朴素 + 贝叶斯

    1.5 应用场景

    • 文本分类(单词作为特征)

    2、朴素贝叶斯算法对文本进行分类

    2.1 案例

    2.2 拉普拉斯平滑系数

    3、API

    4、案例:20类新闻分类

    4.1 步骤分析

    • 获取数据
    • 划分数据集
    • 特征工程  --文本特征抽取
    • 朴素贝叶斯预估器流程
    • 模型评估

    4.2 代码分析

    1. from sklearn.datasets import load_iris, fetch_20newsgroups
    2. from sklearn.feature_extraction.text import TfidfVectorizer
    3. from sklearn.model_selection import train_test_split, GridSearchCV
    4. from sklearn.naive_bayes import MultinomialNB
    5. from sklearn.neighbors import KNeighborsClassifier
    6. from sklearn.preprocessing import StandardScaler
    7. def knn_iris():
    8. # 用KNN 算法对鸢尾花进行分类
    9. # 1、获取数据
    10. iris = load_iris()
    11. # 2、划分数据集
    12. x_train,x_test,y_train,y_test = train_test_split(iris.data,iris.target,random_state=6)
    13. # 3、特征工程 - 标准化
    14. transfer = StandardScaler()
    15. x_train = transfer.fit_transform(x_train)
    16. x_test = transfer.transform(x_test)
    17. # 4、KNN 算法预估器
    18. estimator = KNeighborsClassifier(n_neighbors=3)
    19. estimator.fit(x_train,y_train)
    20. # 5、模型评估
    21. # 方法1 :直接比对真实值和预测值
    22. y_predict = estimator.predict(x_test)
    23. print("y_predict:\n",y_predict)
    24. print("直接比对真实值和预测值:\n",y_test == y_predict)
    25. # 方法2:计算准确率
    26. score = estimator.score(x_test,y_test)
    27. print("准确率为:\n",score)
    28. return None
    29. def knn_iris_gscv():
    30. # 用KNN 算法对鸢尾花进行分类,添加网格搜索和交叉验证
    31. # 1、获取数据
    32. iris = load_iris()
    33. # 2、划分数据集
    34. x_train,x_test,y_train,y_test = train_test_split(iris.data,iris.target,random_state=6)
    35. # 3、特征工程 - 标准化
    36. transfer = StandardScaler()
    37. x_train = transfer.fit_transform(x_train)
    38. x_test = transfer.transform(x_test)
    39. # 4、KNN 算法预估器
    40. estimator = KNeighborsClassifier()
    41. # 加入网格搜索和交叉验证
    42. # 参数准备
    43. param_dict = {"n_neighbors":[1,3,5,7,9,11]}
    44. estimator = GridSearchCV(estimator,param_grid=param_dict,cv=10)
    45. estimator.fit(x_train,y_train)
    46. # 5、模型评估
    47. # 方法1 :直接比对真实值和预测值
    48. y_predict = estimator.predict(x_test)
    49. print("y_predict:\n",y_predict)
    50. print("直接比对真实值和预测值:\n",y_test == y_predict)
    51. # 方法2:计算准确率
    52. score = estimator.score(x_test,y_test)
    53. print("准确率为:\n",score)
    54. # 最佳参数:best_params_
    55. print("最佳参数:\n",estimator.best_params_)
    56. # 最佳结果:best_score_
    57. print("最佳结果:\n",estimator.best_score_)
    58. # 最佳估计值:best_estimator_
    59. print("最佳估计值:\n",estimator.best_estimator_)
    60. # 交叉验证结果:cv_results_
    61. print("交叉验证结果:\n",estimator.cv_results_)
    62. return None
    63. def nb_news():
    64. # 用朴素贝叶斯算法对新闻进行分类
    65. # 1、获取数据
    66. news = fetch_20newsgroups(subset="all")
    67. # 2、划分数据集
    68. x_train,x_test,y_train,y_test = train_test_split(news.data,news.target)
    69. # 3、特征工程:文本特征抽取-tfidf
    70. transfer = TfidfVectorizer()
    71. x_train = transfer.fit_transform(x_train)
    72. x_test = transfer.transform(x_test)
    73. # 4、用朴素贝叶斯算法预估器流程
    74. estimator = MultinomialNB()
    75. estimator.fit(x_train,y_train)
    76. # 5、模型评估
    77. # 方法1 :直接比对真实值和预测值
    78. y_predict = estimator.predict(x_test)
    79. print("y_predict:\n", y_predict)
    80. print("直接比对真实值和预测值:\n", y_test == y_predict)
    81. # 方法2:计算准确率
    82. score = estimator.score(x_test, y_test)
    83. print("准确率为:\n", score)
    84. return None
    85. if __name__ == "__main__":
    86. # 代码1 :用KNN算法对鸢尾花进行分类
    87. # knn_iris()
    88. # 代码2 :用KNN算法对鸢尾花进行分类,添加网格搜索和交叉验证
    89. # knn_iris_gscv()
    90. # 代码3:用朴素贝叶斯算法对新闻进行分类
    91. nb_news()

    5、总结

  • 相关阅读:
    在el-table表头上引入组件不能实时传参bug
    JAVA 基础与进阶系列索引 -- JAVA 进阶系列
    Jetpack架构组件学习(3)——Activity Results API使用
    Qt+Win10使用QAxWidget控件实现远程桌面控制
    [系统安全] malloc的底层原理—ptmalloc堆概述
    被杭州某抖音代运营公司坑了
    前端学习路线(二)
    数组中的 empty
    CVPR 2022 论文和开源项目合集
    在前端使用正则对输入form表单的数据进行格式判断
  • 原文地址:https://blog.csdn.net/qq_48904748/article/details/136123416