/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.examples.ml;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.SparkSession$;
import scala.Array$;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.SeqLike;

public final class DecisionTreeClassificationExample$ {
    public static DecisionTreeClassificationExample$ MODULE$;

    static {
        new DecisionTreeClassificationExample$();
    }

    public void main(String[] args) {
        SparkSession spark = SparkSession$.MODULE$.builder().appName("DecisionTreeClassificationExample").getOrCreate();
        Dataset data = spark.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
        StringIndexerModel labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data);
        VectorIndexerModel featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4).fit(data);
        Dataset[] datasetArray = data.randomSplit(new double[]{0.7, 0.3});
        Option option = Array$.MODULE$.unapplySeq((Object)datasetArray);
        if (option.isEmpty() || option.get() == null || ((SeqLike)option.get()).lengthCompare(2) != 0) {
            throw new MatchError((Object)datasetArray);
        }
        Dataset trainingData = (Dataset)((SeqLike)option.get()).apply(0);
        Dataset testData = (Dataset)((SeqLike)option.get()).apply(1);
        Tuple2 tuple2 = new Tuple2((Object)trainingData, (Object)testData);
        Tuple2 tuple22 = tuple2;
        Dataset trainingData2 = (Dataset)tuple22._1();
        Dataset testData2 = (Dataset)tuple22._2();
        DecisionTreeClassifier dt = (DecisionTreeClassifier)new DecisionTreeClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures");
        IndexToString labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labelsArray()[0]);
        Pipeline pipeline = new Pipeline().setStages((PipelineStage[])((Object[])new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter}));
        PipelineModel model = pipeline.fit(trainingData2);
        Dataset predictions = model.transform(testData2);
        predictions.select("predictedLabel", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"label", "features"})).show(5);
        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy");
        double accuracy = evaluator.evaluate(predictions);
        Predef$.MODULE$.println((Object)new StringBuilder(13).append("Test Error = ").append(1.0 - accuracy).toString());
        DecisionTreeClassificationModel treeModel = (DecisionTreeClassificationModel)model.stages()[2];
        Predef$.MODULE$.println((Object)new StringBuilder(36).append("Learned classification tree model:\n ").append(treeModel.toDebugString()).toString());
        spark.stop();
    }

    private DecisionTreeClassificationExample$() {
        MODULE$ = this;
    }
}

