• 机器学习基础之《回归与聚类算法(6)—模型保存与加载》


    一、背景

    现在我们预测每次都要重新运行一遍模型。完整的流程应该是不断调整阈值重复计算。
    当训练或者计算好一个模型之后,那么如果别人需要我们提供结果预测,就需要保存模型(主要是保存算法的参数)。

    二、sklearn模型的保存和加载API

    1、import joblib
    保存:joblib.dump(rf, "test.pkl")
        rf:是预估器estimator
        test.pkl:是保存的名字
        将预估器序列化保存在本地    
    加载:estimator = joblib.load("test.pkl")

    2、代码

    1. from sklearn.datasets import load_boston
    2. from sklearn.model_selection import train_test_split
    3. from sklearn.preprocessing import StandardScaler
    4. from sklearn.linear_model import LinearRegression, SGDRegressor, Ridge
    5. from sklearn.metrics import mean_squared_error
    6. import joblib
    7. def linear1():
    8. """
    9. 正规方程的优化方法对波士顿房价进行预测
    10. """
    11. # 1、获取数据
    12. boston = load_boston()
    13. # 2、划分数据集
    14. x_train,x_test, y_train, y_test = train_test_split(boston.data, boston.target, random_state=10)
    15. # 3、标准化
    16. transfer = StandardScaler()
    17. x_train = transfer.fit_transform(x_train)
    18. x_test = transfer.transform(x_test)
    19. # 4、预估器
    20. estimator = LinearRegression()
    21. estimator.fit(x_train, y_train)
    22. # 5、得出模型
    23. print("正规方程-权重系数为:\n", estimator.coef_)
    24. print("正规方程-偏置为:\n", estimator.intercept_)
    25. # 6、模型评估
    26. y_predict = estimator.predict(x_test)
    27. print("预测房价:\n", y_predict)
    28. error = mean_squared_error(y_test, y_predict)
    29. print("正规方程-均方误差为:\n", error)
    30. return None
    31. def linear2():
    32. """
    33. 梯度下降的优化方法对波士顿房价进行预测
    34. """
    35. # 1、获取数据
    36. boston = load_boston()
    37. # 2、划分数据集
    38. x_train,x_test, y_train, y_test = train_test_split(boston.data, boston.target, random_state=10)
    39. # 3、标准化
    40. transfer = StandardScaler()
    41. x_train = transfer.fit_transform(x_train)
    42. x_test = transfer.transform(x_test)
    43. # 4、预估器
    44. estimator = SGDRegressor()
    45. estimator.fit(x_train, y_train)
    46. # 5、得出模型
    47. print("梯度下降-权重系数为:\n", estimator.coef_)
    48. print("梯度下降-偏置为:\n", estimator.intercept_)
    49. # 6、模型评估
    50. y_predict = estimator.predict(x_test)
    51. print("预测房价:\n", y_predict)
    52. error = mean_squared_error(y_test, y_predict)
    53. print("梯度下降-均方误差为:\n", error)
    54. return None
    55. def linear3():
    56. """
    57. 岭回归对波士顿房价进行预测
    58. """
    59. # 1、获取数据
    60. boston = load_boston()
    61. # 2、划分数据集
    62. x_train,x_test, y_train, y_test = train_test_split(boston.data, boston.target, random_state=10)
    63. # 3、标准化
    64. transfer = StandardScaler()
    65. x_train = transfer.fit_transform(x_train)
    66. x_test = transfer.transform(x_test)
    67. # 4、预估器
    68. estimator = Ridge()
    69. estimator.fit(x_train, y_train)
    70. # 保存模型
    71. joblib.dump(estimator, "my_ridge.pkl")
    72. # 5、得出模型
    73. print("岭回归-权重系数为:\n", estimator.coef_)
    74. print("岭回归-偏置为:\n", estimator.intercept_)
    75. # 6、模型评估
    76. y_predict = estimator.predict(x_test)
    77. print("预测房价:\n", y_predict)
    78. error = mean_squared_error(y_test, y_predict)
    79. print("岭回归-均方误差为:\n", error)
    80. return None
    81. def linear4():
    82. """
    83. 岭回归对波士顿房价进行预测
    84. """
    85. # 1、获取数据
    86. boston = load_boston()
    87. # 2、划分数据集
    88. x_train,x_test, y_train, y_test = train_test_split(boston.data, boston.target, random_state=10)
    89. # 3、标准化
    90. transfer = StandardScaler()
    91. x_train = transfer.fit_transform(x_train)
    92. x_test = transfer.transform(x_test)
    93. # 加载模型
    94. estimator = joblib.load("my_ridge.pkl")
    95. # 5、得出模型
    96. print("岭回归-权重系数为:\n", estimator.coef_)
    97. print("岭回归-偏置为:\n", estimator.intercept_)
    98. # 6、模型评估
    99. y_predict = estimator.predict(x_test)
    100. print("预测房价:\n", y_predict)
    101. error = mean_squared_error(y_test, y_predict)
    102. print("岭回归-均方误差为:\n", error)
    103. return None
    104. if __name__ == "__main__":
    105. # 代码1:正规方程的优化方法对波士顿房价进行预测
    106. linear1()
    107. # 代码2:梯度下降的优化方法对波士顿房价进行预测
    108. linear2()
    109. # 代码3:岭回归对波士顿房价进行预测
    110. linear3()
    111. # 代码4:加载模型
    112. linear4()

  • 相关阅读:
    UI 自动化测试 —— selenium的简单介绍和使用
    端口扫描技术
    CSS 布局 (三) 浮动、定位、多列布局
    unity模型制作(终章)
    Java—简单斗地主(集合练习)
    2023-11 | 短视频批量下载/爬取某个用户的所有视频 | Python
    坦克大战游戏开发中的设计模式总结
    【Python】OpenCV-图片差异检测与标注
    百题千解计划【CSDN每日一练】计数问题(附解析+多种实现方法:Python、Java、C、C++、JavaScript、C#、go)
    Linux命令(22)之chage
  • 原文地址:https://blog.csdn.net/csj50/article/details/134394889