Деревья решений, случайный лес и повышение градиента в Scala Spark
В машинном обучении деревья решений, случайные леса и повышение градиента являются популярными алгоритмами для решения задач классификации и регрессии. Эти алгоритмы являются мощными инструментами для построения прогностических моделей, которые могут делать точные прогнозы на основе новых, невидимых данных.
В этом сообщении блога мы рассмотрим, как реализовать эти алгоритмы в Scala Spark.
Предыдущая глава:
Я буду использовать данные о раке от Kaggle.
Подготовка данных
- Во-первых, давайте создадим сеанс Spark.
import org.apache.spark.sql.SparkSession import org.apache.log4j._ Logger.getLogger("org").setLevel(Level.ERROR) // Create a SparkSession val spark = SparkSession.builder().getOrCreate()
- Прочитайте файл CSV и удалите ненужные столбцы.
//read csv var data = spark.read.option("header","true").option("inferSchema","true"). format("csv").load("data_cancer.csv") // drop last column data = data.drop(data.columns.last) data.printSchema()
- Разделите данные как обучающие и тестовые подмножества. Результирующая переменная
splits
представляет собой массив фреймов данных, где первый фрейм данных содержит 70 % строк изdata
, а второй фрейм данных содержит оставшиеся 30 % строк.
val splits = data.randomSplit(Array(0.7,0.3)) val train = splits(0) val test = splits(1) val train_rows = train.count() val test_rows = test.count() println("Training rows: " + train_rows + " Test rows: " + test_rows) //Training rows: 385 Test rows: 184
- Соберите ассемблер. Во-первых, мы исключим ненужный столбец id и целевую функцию diagnose.
VectorAssembler
— это преобразователь, который объединяет несколько столбцов DataFrame в один векторный столбец.
import org.apache.spark.ml.feature.VectorAssembler import org.apache.spark.sql.functions.array_except // exclude columns in inputs val colsToExclude = Seq("id", "diagnosis") val inputCols = data.columns.filter(!colsToExclude.contains(_)) // build assembler val assembler = new VectorAssembler().setInputCols(inputCols).setOutputCol("features") val training = assembler.transform(train).select($"features", $"diagnosis".alias("label")) training.show()
Древо решений
- Постройте модель дерева решений и обучите ее.
import org.apache.spark.ml.classification.DecisionTreeClassificationModel import org.apache.spark.ml.classification.DecisionTreeClassifier val dt = new DecisionTreeClassifier().setLabelCol("label").setFeaturesCol("features") val modelDt = dt.fit(training)
- Давайте повторим для проверки то, что мы сделали для обучения. Затем, используя метод
transform
, мы получим прогнозы.
val testing = assembler.transform(test).select($"features", $"diagnosis".alias("trueLabel")) val predictionDt = modelDt.transform(testing) val predictedDt = predictionDt.select("features","prediction","probability","trueLabel") predictedDt.show()
- Наконец, давайте оценим модель.
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator val evaluator = new BinaryClassificationEvaluator().setLabelCol("trueLabel").setRawPredictionCol("prediction") val metricDt = evaluator.evaluate(predictionDt) //evaluator: org.apache.spark.ml.evaluation.BinaryClassificationEvaluator = BinaryClassificationEvaluator: uid=binEval_01ea9dc7805b, me //tricName=areaUnderROC, numBins=1000 //metricDt: Double = 0.9492124718739956
Случайный лес
Мы просто повторяем то, что делали для деревьев решений. Единственная разница в том, что на этот раз мы будем использовать класс RandomForestClassifier
.
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} val rf = new RandomForestClassifier().setLabelCol("label").setFeaturesCol("features") val modelRf = rf.fit(training) val predictionRf = modelRf.transform(testing) val predictedRf = predictionRf.select("features","prediction","probability","trueLabel") val evaluatorRf = new BinaryClassificationEvaluator().setLabelCol("trueLabel").setRawPredictionCol("prediction") val metricRf = evaluatorRf.evaluate(predictionRf) //metricDt: Double = 0.9117251579938146
Повышение градиента
import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier} val gb = new GBTClassifier().setLabelCol("label").setFeaturesCol("features") val modelGb = gb.fit(training) val predictionGb = modelGb.transform(testing) val predictedGb = predictionGb.select("features","prediction","probability","trueLabel") val evaluatorGb = new BinaryClassificationEvaluator().setLabelCol("trueLabel").setRawPredictionCol("prediction") val metricGb = evaluatorGb.evaluate(predictionRf) //metricGb: Double = 0.9643206256109482
Читать далее
Источники
https://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier