package org.apache.spark.ml.stat;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:org/apache/spark/ml/stat/JavaSummarizerSuite.class */
/**
 * Exercises the {@link Summarizer} metric builder through its Java-facing API against a
 * small two-row dataset of {@link LabeledPoint}s.
 */
public class JavaSummarizerSuite extends SharedSparkSession {
    /** Two-row DataFrame built from {@link LabeledPoint} beans; rebuilt before each test. */
    private transient Dataset<Row> dataset;

    /**
     * Starts the shared session, then builds a DataFrame of two {@link LabeledPoint}s
     * (both with label 0.0) whose feature vectors are [1.0, 2.0] and [3.0, 4.0],
     * parallelized over two partitions.
     *
     * @throws IOException if the parent session setup fails
     */
    @Override
    @BeforeEach
    public void setUp() throws IOException {
        super.setUp();
        List<LabeledPoint> points = new ArrayList<>();
        points.add(new LabeledPoint(0.0, Vectors.dense(1.0, 2.0)));
        points.add(new LabeledPoint(0.0, Vectors.dense(3.0, 4.0)));
        // Two partitions so the summary aggregation has to merge partial results.
        this.dataset = this.spark.createDataFrame(this.jsc.parallelize(points, 2), LabeledPoint.class);
    }

    /**
     * Requests the "mean", "max" and "count" metrics over the "features" column and checks
     * each against the values computed by hand from the two fixture rows.
     */
    @Test
    public void testSummarizer() {
        // The summary column is a single struct whose fields are named after the requested metrics.
        Row struct = this.dataset
            .select(Summarizer.metrics("mean", "max", "count").summary(functions.col("features")))
            .first()
            .getStruct(0);

        Vector mean = struct.getAs("mean");
        Vector max = struct.getAs("max");
        long count = struct.getAs("count");

        Assertions.assertEquals(2L, count);
        // Means (1+3)/2 and (2+4)/2 are exact in binary floating point, so a zero delta is safe.
        Assertions.assertArrayEquals(new double[]{2.0, 3.0}, mean.toArray(), 0.0);
        Assertions.assertArrayEquals(new double[]{3.0, 4.0}, max.toArray(), 0.0);
    }
}
