• SparkMlib 之随机森林及其案例


    什么是随机森林?

    随机森林算法是机器学习、计算机视觉等领域内应用极为广泛的一个算法,它不仅可以用来做分类,也可用来做回归即预测,随机森林机由多个决策树构成,相比于单个决策树算法,它分类、预测效果更好,不容易出现过度拟合的情况。

    常应用于以下类型的场景:

    1. 预测用户贷款是否能够按时还款;
    2. 预测用户是否会购买某件商品等等

    官网:分类和回归

    随机森林的优缺点

    优点:

    1. 可以处理高纬度的数据;

    2. 训练之前不需要特意的做特征选择;

    3. 建立很多树,预防了过拟合风险;

    缺点:

    1. 计算量相对于决策树很大,性能开销很大。

    2. 可能会导致有些数据集没有训练到,但这种几率很小。

    3. 分裂的时候,偏向于选择取值较多的特征。

    随机森林示例——鸢尾花分类

    数据集下载:

    链接:
    https://pan.baidu.com/s/1AshgNxx1wOWhLgKxgjrZww?pwd=lz3l 
    
    提取码:
    lz3l
    
    • 1
    • 2
    • 3
    • 4
    • 5

    数据集介绍:

    iris.scale.txtlibsvm 格式的鸢尾花数据集,共有五个字段。第一个为标签字段,后四个为特征字段。

    libsvm 格式参考:机器学习:libsvm数据格式

    将数据集中的随机百分之70作为训练集,剩余的作为测试集。

    使用 SparkSQL 的方式读取 libsvm 格式的文件会自动生成 labelfeatures 结构的数据,如下所示:

    val data: DataFrame = spark.read.format("libsvm").load("iris.scale.txt")
    
    data.show()
    
    • 1
    • 2
    • 3

    需求实现:

    import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    import org.apache.spark.ml.feature._
    import org.apache.spark.ml.{Pipeline, PipelineModel}
    import org.apache.spark.sql.{DataFrame, SparkSession}
    
    object Iris {
    
        def main(args: Array[String]): Unit = {
    
            val spark: SparkSession = SparkSession.builder().appName("Iris").master("local[*]").getOrCreate()
    
            // 加载 libsvm 格式文件的数据
            val data: DataFrame = spark.read.format("libsvm").load("C:\\Users\\Administrator\\Desktop\\iris.scale.txt")
    
            data.show()
    
            // 1.构建标签列转换对象
            val labelIndexer: StringIndexerModel = new StringIndexer()
                    .setInputCol("label")
                    .setOutputCol("indexedLabel")
                    .fit(data)
    
            // 2.构建特征列转换对象,设置特征列数量
            val featureIndexer: VectorIndexerModel = new VectorIndexer()
                    .setInputCol("features")
                    .setOutputCol("indexedFeatures")
                    .setMaxCategories(4)
                    .fit(data)
    
            // 3.将随机百分之70作为训练集,其余为测试集
            val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
    
            // 4.创建随机森林对象,设置标签列与特征列以及决策树的个数
            val rf: RandomForestClassifier = new RandomForestClassifier()
                    .setLabelCol("indexedLabel")
                    .setFeaturesCol("indexedFeatures")
                    .setNumTrees(10)
    
            // 5.设置预测列标签
            val labelConverter: IndexToString = new IndexToString()
                    .setInputCol("prediction")
                    .setOutputCol("predictedLabel")
                    .setLabels(labelIndexer.labelsArray(0))
    
            // 6.管道组装
            val pipeline: Pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))
    
            // 7.模型训练
            val model: PipelineModel = pipeline.fit(trainingData)
    
            // 8.模型预测
            val predictions: DataFrame = model.transform(testData)
    
            // 9.模型评估
            predictions.select("predictedLabel", "label", "features").show()
    
            // 10.创建错误率的计算对象
            val evaluator: MulticlassClassificationEvaluator = new MulticlassClassificationEvaluator()
                    .setLabelCol("indexedLabel")
                    .setPredictionCol("prediction")
                    .setMetricName("accuracy")
    
            // 11.计算错误率
            val accuracy: Double = evaluator.evaluate(predictions)
            println(s"Test Error = ${(1.0 - accuracy)}")
    
            // 12.打印随机森林模型
            val rfModel: RandomForestClassificationModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
            println(s"Learned classification forest model:\n ${rfModel.toDebugString}")
    
        }
    
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
  • 相关阅读:
    小程序+egg来实现获取用户手机号
    传输中的差错检验技术
    笔试题-构建非二叉树,且非递归遍历-利用栈
    ETL实现实时文件监听
    c# Class vs Structure
    端到端流程总结
    AndroidStudio测试类无法运行
    2022年十次最大的云中断
    聚焦千兆光模块和万兆光模块的测试技术及设备
    Spring-boot初级
  • 原文地址:https://blog.csdn.net/weixin_46389691/article/details/128113952