决策树进行鸢尾花分类的案例
背景说明:
本案例使用 IDEA + Spark 3.4.1 + sbt 1.9.3 + Spark MLlib 构建鸢尾花决策树分类预测模型。这是一个分类模型案例,通过它可以快速了解 Spark MLlib 分类预测模型的使用方法。输入文件 iris.txt 的每一行由 4 个以逗号分隔的数值特征和 1 个类别标签组成。
依赖
// sbt build definition for the SparkLearning project.
// Each setting lives on its own line: sbt cannot parse several top-level
// expressions fused onto a single physical line.
ThisBuild / version := "0.1.0-SNAPSHOT"

ThisBuild / scalaVersion := "2.13.11"

// Single source of truth for the Spark version so every Spark module is
// upgraded together instead of editing five separate literals.
val sparkVersion = "3.4.1"

lazy val root = (project in file("."))
  .settings(
    name := "SparkLearning",
    // NOTE(review): idePackagePrefix is not a core sbt key; presumably it is
    // provided by an IDE-settings plugin declared elsewhere in the build — confirm.
    idePackagePrefix := Some("cn.lh.spark"),
    libraryDependencies ++= Seq(
      // `%%` appends the Scala binary version (here _2.13) to the artifact name.
      "org.apache.spark" %% "spark-sql" % sparkVersion,
      "org.apache.spark" %% "spark-core" % sparkVersion,
      "org.apache.hadoop" % "hadoop-auth" % "3.3.6",
      "org.apache.spark" %% "spark-streaming" % sparkVersion,
      "org.apache.spark" %% "spark-streaming-kafka-0-10" % sparkVersion,
      "org.apache.spark" %% "spark-mllib" % sparkVersion,
      "mysql" % "mysql-connector-java" % "8.0.30"
    )
  )
完整代码
package cn.lh.spark

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, StringIndexerModel, VectorIndexer, VectorIndexerModel}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

// case class Iris(features: org.apache.spark.ml.linalg.Vector, label: String)
// Per the original author's note, the Iris case class above already exists next to
// MLlibLogisticRegression in this package, so it is not declared again here.
// Uncomment it when the package does not already contain such a case class.

/**
 * Decision-tree classifier that classifies iris flowers with Spark MLlib.
 *
 * The program reads a comma-separated text file (four numeric features followed by a
 * label per line), trains a decision tree on a random 70% split, prints a few
 * predictions for the remaining 30%, the test error, and the learned tree.
 *
 * An optional first program argument overrides the input file path; without it the
 * original hard-coded path is used.
 */
object MLlibDecisionTreeClassifier {

  /** Input file used when no path is supplied on the command line. */
  private val DefaultIrisPath: String = "F:\\niit\\2023\\2023_2\\Spark\\codes\\data\\iris.txt"

  /**
   * Program entry point.
   *
   * @param args optional; `args(0)`, when present, is the path of the iris data file
   */
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().master("local[2]")
      .appName("Spark MLlib DecisionTreeClassifier").getOrCreate()

    // Stop the session even when loading, training or evaluation throws.
    try run(spark, args.headOption.getOrElse(DefaultIrisPath))
    finally spark.stop()
  }

  /** Loads the iris data from `irisPath`, then trains, evaluates and prints the decision tree. */
  private def run(spark: SparkSession, irisPath: String): Unit = {
    import spark.implicits._

    // Blank lines (e.g. a trailing newline at the end of the file) are skipped; they
    // would otherwise fail in toDouble.
    val irisRDD: RDD[Iris] = spark.sparkContext.textFile(irisPath)
      .filter(_.trim.nonEmpty)
      .map(_.split(","))
      .map(p => Iris(Vectors.dense(p(0).toDouble, p(1).toDouble, p(2).toDouble, p(3).toDouble), p(4)))

    val data: DataFrame = irisRDD.toDF()
    data.show()

    data.createOrReplaceTempView("iris")
    val df: DataFrame = spark.sql("select * from iris")

    println("鸢尾花原始数据如下:")
    // Column 1 is the label and column 0 the feature vector, printed as "label:features".
    df.map(t => s"${t(1)}:${t(0)}").collect().foreach(println)

    // Index the label and the features; both indexers are fitted on the full data set.
    val labelIndexer: StringIndexerModel = new StringIndexer()
      .setInputCol("label").setOutputCol("indexedLabel").fit(df)
    val featureIndexer: VectorIndexerModel = new VectorIndexer()
      .setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4).fit(df)

    // labelConverter turns the predicted numeric class back into its string label.
    // labelsArray(0) replaces StringIndexerModel.labels, which is deprecated in Spark 3.x.
    val labelConverter: IndexToString = new IndexToString()
      .setInputCol("prediction").setOutputCol("predictedLabel")
      .setLabels(labelIndexer.labelsArray(0))

    // Randomly split the data into a training set (70%) and a test set (30%).
    val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

    val dtClassifier: DecisionTreeClassifier = new DecisionTreeClassifier()
      .setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")

    // Chain indexers, classifier and label converter into one pipeline.
    val pipelinedClassifier: Pipeline = new Pipeline()
      .setStages(Array(labelIndexer, featureIndexer, dtClassifier, labelConverter))

    // Train the decision-tree model.
    val modelClassifier: PipelineModel = pipelinedClassifier.fit(trainingData)

    // Predict on the held-out test set.
    val predictionsClassifier: DataFrame = modelClassifier.transform(testData)
    predictionsClassifier.select("predictedLabel", "label", "features").show(5)

    // Evaluate the decision-tree classification model.
    val evaluatorClassifier: MulticlassClassificationEvaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction").setMetricName("accuracy")
    val accuracy: Double = evaluatorClassifier.evaluate(predictionsClassifier)
    println(s"Test Error = ${1.0 - accuracy}")

    // Locate the fitted tree by type instead of downcasting a fixed stage index.
    val treeModelClassifier: Option[DecisionTreeClassificationModel] =
      modelClassifier.stages.collectFirst { case tree: DecisionTreeClassificationModel => tree }
    treeModelClassifier.foreach { tree =>
      println(s"Learned classification tree model:\n${tree.toDebugString}")
    }
  }
}