package org.apache.spark.ml.classification;

import java.io.IOException;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import scala.jdk.javaapi.CollectionConverters;

/* loaded from: input_file:org/apache/spark/ml/classification/JavaOneVsRestSuite.class */
/**
 * Java-side smoke test for {@link OneVsRest}: fits a one-vs-rest model backed by
 * {@link LogisticRegression} on a small generated dataset and checks the
 * label/prediction column names on both the estimator and the fitted model.
 */
public class JavaOneVsRestSuite extends SharedSparkSession {
    /** DataFrame view of {@link #datasetRDD}; rebuilt before every test. */
    private transient Dataset<Row> dataset;
    /** Generated labeled points, parallelized into two partitions. */
    private transient JavaRDD<LabeledPoint> datasetRDD;

    /**
     * Runs the parent setup, then builds the RDD and DataFrame fixtures from
     * {@code LogisticRegressionSuite.generateMultinomialLogisticInput}.
     *
     * @throws IOException declared by this override; presumably raised by the parent setup
     */
    @Override
    @BeforeEach
    public void setUp() throws IOException {
        super.setUp();

        // NOTE(review): the local names below are guesses based on the generator's
        // name and argument order -- confirm against its definition.
        double[] coefficients = {
            -0.57997d, 0.912083d, -0.371077d, -0.819866d, 2.688191d,
            -0.16624d, -0.84355d, -0.048509d, -0.301789d, 4.170682d
        };
        double[] featureMeans = {5.843d, 3.057d, 3.758d, 1.199d};
        double[] featureVariances = {0.6856d, 0.1899d, 3.116d, 0.581d};
        boolean fifthArgument = true;
        int sixthArgument = 3;
        int seventhArgument = 42;
        int numSlices = 2;

        this.datasetRDD = this.jsc.parallelize(
            CollectionConverters.asJava(
                LogisticRegressionSuite.generateMultinomialLogisticInput(
                    coefficients,
                    featureMeans,
                    featureVariances,
                    fifthArgument,
                    sixthArgument,
                    seventhArgument)),
            numSlices);
        this.dataset = this.spark.createDataFrame(this.datasetRDD, LabeledPoint.class);
    }

    /**
     * Checks the label/prediction column names on a {@link OneVsRest} estimator and on
     * the model it fits, and runs a transform over the fixture dataset in between.
     */
    @Test
    public void oneVsRestDefaultParams() {
        OneVsRest estimator = new OneVsRest();
        estimator.setClassifier(new LogisticRegression());
        Assertions.assertEquals("label", estimator.getLabelCol());
        Assertions.assertEquals("prediction", estimator.getPredictionCol());

        OneVsRestModel model = estimator.fit(this.dataset);

        // The collected rows are discarded: this only checks that transform + select
        // run to completion on the fitted model.
        Dataset<Row> predictions = model.transform(this.dataset).select("label", "prediction");
        predictions.collectAsList();

        Assertions.assertEquals("label", model.getLabelCol());
        Assertions.assertEquals("prediction", model.getPredictionCol());
    }
}
