package org.apache.spark.mllib.clustering;

import java.util.Arrays;
import java.util.List;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Java-API tests for {@link KMeans} and {@link KMeansModel}.
 *
 * <p>Every test clusters the same three 3-dimensional points with {@code k = 1}. The single
 * learned center must therefore be their coordinate-wise mean, {@code (1.0, 3.0, 4.0)}, whichever
 * initialization mode is used.
 */
public class JavaKMeansSuite extends SharedSparkSession {

    /** Coordinate-wise mean of {@link #points()}: the only possible center when k = 1. */
    private static final Vector EXPECTED_CENTER = Vectors.dense(1.0, 3.0, 4.0);

    /** Returns the training points shared by every test. */
    private static List<Vector> points() {
        return Arrays.asList(
            Vectors.dense(1.0, 2.0, 6.0),
            Vectors.dense(1.0, 3.0, 0.0),
            Vectors.dense(1.0, 4.0, 6.0));
    }

    /** Distributes {@link #points()} over two partitions of the shared Java Spark context. */
    private JavaRDD<Vector> pointsRdd() {
        return this.jsc.parallelize(points(), 2);
    }

    @Test
    public void runKMeansUsingStaticMethods() {
        JavaRDD<Vector> data = pointsRdd();

        KMeansModel model = KMeans.train(data.rdd(), 1, 1, KMeans.K_MEANS_PARALLEL());
        Assertions.assertEquals(1, model.clusterCenters().length);
        Assertions.assertEquals(EXPECTED_CENTER, model.clusterCenters()[0]);

        // Random initialization must converge to the same center, since k = 1.
        model = KMeans.train(data.rdd(), 1, 1, KMeans.RANDOM());
        Assertions.assertEquals(EXPECTED_CENTER, model.clusterCenters()[0]);
    }

    @Test
    public void runKMeansUsingConstructor() {
        JavaRDD<Vector> data = pointsRdd();

        KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
        Assertions.assertEquals(1, model.clusterCenters().length);
        Assertions.assertEquals(EXPECTED_CENTER, model.clusterCenters()[0]);

        model = new KMeans()
            .setK(1)
            .setMaxIterations(1)
            .setInitializationMode(KMeans.RANDOM())
            .run(data.rdd());
        Assertions.assertEquals(EXPECTED_CENTER, model.clusterCenters()[0]);
    }

    @Test
    public void testPredictJavaRDD() {
        JavaRDD<Vector> data = pointsRdd();
        KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());

        // Exercises the JavaRDD overload of predict and reads its elements back as Integers.
        List<Integer> predictions = model.predict(data).collect();

        Assertions.assertEquals(points().size(), predictions.size());
        // With a single cluster, every point must be assigned to cluster index 0.
        for (Integer cluster : predictions) {
            Assertions.assertEquals(0, cluster.intValue());
        }
    }
}
