package org.apache.spark.ml.feature;

import java.util.Arrays;
import java.util.List;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:org/apache/spark/ml/feature/JavaTargetEncoderSuite.class */
/**
 * Java-API tests for {@link TargetEncoder}.
 *
 * <p>Each test builds a nine-row DataFrame with three categorical inputs ({@code input1} short,
 * {@code input2} int, {@code input3} nullable double), a {@code label}, and precomputed expected
 * encodings. The last row has a null {@code input3}, so a null category is exercised too. The
 * encoder is fitted on that frame. Its output is checked against {@code expected1..3} without
 * setting smoothing, and against {@code smoothing1..3} after {@code setSmoothing(1.0)}.
 *
 * <p>Expected values are stored as literal doubles and compared with exact equality. They must
 * therefore match the encoder's arithmetic bit-for-bit.
 */
public class JavaTargetEncoderSuite extends SharedSparkSession {
    @Test
    public void testTargetEncoderBinary() {
        // Columns per row: input1, input2, input3, label /
        //                  expected1, expected2, expected3 /
        //                  smoothing1, smoothing2, smoothing3
        List<Row> data = Arrays.asList(
            RowFactory.create((short) 0, 3, 5.0, 0.0,
                0.3333333333333333, 0.0, 0.3333333333333333,
                0.3611111111111111, 0.07407407407407406, 0.3611111111111111),
            RowFactory.create((short) 1, 4, 5.0, 1.0,
                0.6666666666666666, 1.0, 0.3333333333333333,
                0.6111111111111112, 0.888888888888889, 0.3611111111111111),
            RowFactory.create((short) 2, 3, 5.0, 0.0,
                0.3333333333333333, 0.0, 0.3333333333333333,
                0.3611111111111111, 0.07407407407407406, 0.3611111111111111),
            RowFactory.create((short) 0, 4, 6.0, 1.0,
                0.3333333333333333, 1.0, 0.6666666666666666,
                0.3611111111111111, 0.888888888888889, 0.6111111111111112),
            RowFactory.create((short) 1, 3, 6.0, 0.0,
                0.6666666666666666, 0.0, 0.6666666666666666,
                0.6111111111111112, 0.07407407407407406, 0.6111111111111112),
            RowFactory.create((short) 2, 4, 6.0, 1.0,
                0.3333333333333333, 1.0, 0.6666666666666666,
                0.3611111111111111, 0.888888888888889, 0.6111111111111112),
            RowFactory.create((short) 0, 3, 7.0, 0.0,
                0.3333333333333333, 0.0, 0.0,
                0.3611111111111111, 0.07407407407407406, 0.2222222222222222),
            RowFactory.create((short) 1, 4, 8.0, 1.0,
                0.6666666666666666, 1.0, 1.0,
                0.6111111111111112, 0.888888888888889, 0.7222222222222222),
            RowFactory.create((short) 2, 3, null, 0.0,
                0.3333333333333333, 0.0, 0.0,
                0.3611111111111111, 0.07407407407407406, 0.2222222222222222));

        verifyEncoding(data, "binary");
    }

    @Test
    public void testTargetEncoderContinuous() {
        // Columns per row: input1, input2, input3, label /
        //                  expected1, expected2, expected3 /
        //                  smoothing1, smoothing2, smoothing3
        List<Row> data = Arrays.asList(
            RowFactory.create((short) 0, 3, 5.0, 10.0, 40.0, 50.0, 20.0, 42.5, 50.0, 27.5),
            RowFactory.create((short) 1, 4, 5.0, 20.0, 50.0, 50.0, 20.0, 50.0, 50.0, 27.5),
            RowFactory.create((short) 2, 3, 5.0, 30.0, 60.0, 50.0, 20.0, 57.5, 50.0, 27.5),
            RowFactory.create((short) 0, 4, 6.0, 40.0, 40.0, 50.0, 50.0, 42.5, 50.0, 50.0),
            RowFactory.create((short) 1, 3, 6.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0),
            RowFactory.create((short) 2, 4, 6.0, 60.0, 60.0, 50.0, 50.0, 57.5, 50.0, 50.0),
            RowFactory.create((short) 0, 3, 7.0, 70.0, 40.0, 50.0, 70.0, 42.5, 50.0, 60.0),
            RowFactory.create((short) 1, 4, 8.0, 80.0, 50.0, 50.0, 80.0, 50.0, 50.0, 65.0),
            RowFactory.create((short) 2, 3, null, 90.0, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0));

        verifyEncoding(data, "continuous");
    }

    /**
     * Fits a {@link TargetEncoder} of the given target type on {@code rows} and checks its output.
     *
     * @param rows test rows laid out as described by {@link #createDataFrame(List)}
     * @param targetType value passed to {@code TargetEncoder.setTargetType}
     */
    private void verifyEncoding(List<Row> rows, String targetType) {
        Dataset<Row> dataset = createDataFrame(rows);
        TargetEncoderModel model = fitEncoder(dataset, targetType);

        // Pass 1: smoothing is not set explicitly on the model.
        assertOutputsMatch(model.transform(dataset), "expected");

        // Pass 2: setSmoothing is applied to the same model instance, so keep this pass last.
        assertOutputsMatch(model.setSmoothing(1.0).transform(dataset), "smoothing");
    }

    /**
     * Builds a DataFrame over {@code rows} with the schema shared by both tests.
     *
     * <p>The columns are: three nullable inputs (short, int, double), then a non-null double
     * {@code label}, then non-null double columns {@code expected1..3} and {@code smoothing1..3}.
     */
    private Dataset<Row> createDataFrame(List<Row> rows) {
        StructField[] fields = {
            DataTypes.createStructField("input1", DataTypes.ShortType, true),
            DataTypes.createStructField("input2", DataTypes.IntegerType, true),
            DataTypes.createStructField("input3", DataTypes.DoubleType, true),
            DataTypes.createStructField("label", DataTypes.DoubleType, false),
            DataTypes.createStructField("expected1", DataTypes.DoubleType, false),
            DataTypes.createStructField("expected2", DataTypes.DoubleType, false),
            DataTypes.createStructField("expected3", DataTypes.DoubleType, false),
            DataTypes.createStructField("smoothing1", DataTypes.DoubleType, false),
            DataTypes.createStructField("smoothing2", DataTypes.DoubleType, false),
            DataTypes.createStructField("smoothing3", DataTypes.DoubleType, false)
        };
        return this.spark.createDataFrame(rows, DataTypes.createStructType(fields));
    }

    /** Fits an encoder mapping {@code input1..3} to {@code output1..3} for the given target type. */
    private static TargetEncoderModel fitEncoder(Dataset<Row> dataset, String targetType) {
        return new TargetEncoder()
            .setInputCols(new String[]{"input1", "input2", "input3"})
            .setOutputCols(new String[]{"output1", "output2", "output3"})
            .setTargetType(targetType)
            .fit(dataset);
    }

    /**
     * Asserts that {@code output1..3} equal {@code <expectedPrefix>1..3} row by row.
     *
     * <p>The reference columns are passed as JUnit's {@code expected} argument. A failure
     * therefore reports the precomputed values as "expected" and the model output as "actual".
     */
    private static void assertOutputsMatch(Dataset<Row> result, String expectedPrefix) {
        Assertions.assertEquals(
            result.select(expectedPrefix + "1", expectedPrefix + "2", expectedPrefix + "3")
                .collectAsList(),
            result.select("output1", "output2", "output3").collectAsList());
    }
}
