package org.apache.spark.mllib.util;

import java.util.Arrays;
import java.util.Collections;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.mllib.linalg.Matrices;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.MatrixUDT;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Java-API tests for the {@link MLUtils} helpers that convert vector and matrix columns between
 * the old {@code org.apache.spark.mllib.linalg} types and their {@code org.apache.spark.ml.linalg}
 * counterparts.
 */
public class JavaMLUtilsSuite extends SharedSparkSession {

    /**
     * Round-trips the {@code features} vector column of a single-row DataFrame through
     * {@link MLUtils#convertVectorColumnsToML} and {@link MLUtils#convertVectorColumnsFromML}.
     */
    @Test
    public void testConvertVectorColumnsToAndFromML() {
        Vector x = Vectors.dense(2.0);
        Dataset<Row> dataset = spark.createDataFrame(
                Collections.singletonList(new LabeledPoint(1.0, x)), LabeledPoint.class)
            .select("label", "features");

        // Without explicit column names, the "features" column is expected to be converted.
        Dataset<Row> newDataset1 = MLUtils.convertVectorColumnsToML(dataset);
        Row new1 = newDataset1.first();
        Assertions.assertEquals(RowFactory.create(1.0, x.asML()), new1);

        // Naming the column explicitly must produce the same row.
        Row new2 = MLUtils.convertVectorColumnsToML(dataset, "features").first();
        Assertions.assertEquals(new1, new2);

        // Converting back must restore the original mllib vector.
        Row old1 = MLUtils.convertVectorColumnsFromML(newDataset1).first();
        Assertions.assertEquals(RowFactory.create(1.0, x), old1);
    }

    /**
     * Round-trips the {@code features} matrix column of a single-row DataFrame through
     * {@link MLUtils#convertMatrixColumnsToML} and {@link MLUtils#convertMatrixColumnsFromML}.
     */
    @Test
    public void testConvertMatrixColumnsToAndFromML() {
        Matrix x = Matrices.dense(2, 1, new double[]{1.0, 2.0});
        // The schema is given explicitly so the column is typed with the mllib MatrixUDT.
        StructType schema = new StructType(new StructField[]{
            new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("features", new MatrixUDT(), false, Metadata.empty())
        });
        Dataset<Row> dataset = spark.createDataFrame(
            Arrays.asList(RowFactory.create(1.0, x)), schema);

        // Without explicit column names, the "features" column is expected to be converted.
        Dataset<Row> newDataset1 = MLUtils.convertMatrixColumnsToML(dataset);
        Row new1 = newDataset1.first();
        Assertions.assertEquals(RowFactory.create(1.0, x.asML()), new1);

        // Naming the column explicitly must produce the same row.
        Row new2 = MLUtils.convertMatrixColumnsToML(dataset, "features").first();
        Assertions.assertEquals(new1, new2);

        // Converting back must restore the original mllib matrix.
        Row old1 = MLUtils.convertMatrixColumnsFromML(newDataset1).first();
        Assertions.assertEquals(RowFactory.create(1.0, x), old1);
    }
}
