package org.apache.spark.sql.connect.ml;

import java.util.Arrays;

import org.apache.spark.connect.proto.DataType;
import org.apache.spark.connect.proto.Doubles;
import org.apache.spark.connect.proto.Expression;
import org.apache.spark.connect.proto.Fetch;
import org.apache.spark.connect.proto.Ints;
import org.apache.spark.connect.proto.MlParams;
import org.apache.spark.ml.linalg.DenseMatrix;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.SparseMatrix;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.Params;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter$;
import org.apache.spark.sql.connect.common.ProtoDataTypes$;
import org.apache.spark.sql.connect.service.SessionHolder;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.ArrayOps$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

/* compiled from: Serializer.scala */
/* loaded from: input_file:org/apache/spark/sql/connect/ml/Serializer$.class */
public final class Serializer$ {
    public static final Serializer$ MODULE$ = new Serializer$();

    public Expression.Literal serializeParam(Object obj) {
        if (obj instanceof SparseVector) {
            SparseVector sparseVector = (SparseVector) obj;
            Expression.Literal.Struct.Builder newBuilder = Expression.Literal.Struct.newBuilder();
            newBuilder.setStructType(ProtoDataTypes$.MODULE$.VectorUDT());
            newBuilder.addElements(Expression.Literal.newBuilder().setByte(0));
            newBuilder.addElements(Expression.Literal.newBuilder().setInteger(sparseVector.size()));
            newBuilder.addElements(buildIntArray(sparseVector.indices()));
            newBuilder.addElements(buildDoubleArray(sparseVector.values()));
            return Expression.Literal.newBuilder().setStruct(newBuilder).build();
        }
        if (obj instanceof DenseVector) {
            Expression.Literal.Struct.Builder newBuilder2 = Expression.Literal.Struct.newBuilder();
            newBuilder2.setStructType(ProtoDataTypes$.MODULE$.VectorUDT());
            newBuilder2.addElements(Expression.Literal.newBuilder().setByte(1));
            newBuilder2.addElements(Expression.Literal.newBuilder().setNull(ProtoDataTypes$.MODULE$.NullType()));
            newBuilder2.addElements(Expression.Literal.newBuilder().setNull(ProtoDataTypes$.MODULE$.NullType()));
            newBuilder2.addElements(buildDoubleArray(((DenseVector) obj).values()));
            return Expression.Literal.newBuilder().setStruct(newBuilder2).build();
        }
        if (obj instanceof SparseMatrix) {
            SparseMatrix sparseMatrix = (SparseMatrix) obj;
            Expression.Literal.Struct.Builder newBuilder3 = Expression.Literal.Struct.newBuilder();
            newBuilder3.setStructType(ProtoDataTypes$.MODULE$.MatrixUDT());
            newBuilder3.addElements(Expression.Literal.newBuilder().setByte(0));
            newBuilder3.addElements(Expression.Literal.newBuilder().setInteger(sparseMatrix.numRows()));
            newBuilder3.addElements(Expression.Literal.newBuilder().setInteger(sparseMatrix.numCols()));
            newBuilder3.addElements(buildIntArray(sparseMatrix.colPtrs()));
            newBuilder3.addElements(buildIntArray(sparseMatrix.rowIndices()));
            newBuilder3.addElements(buildDoubleArray(sparseMatrix.values()));
            newBuilder3.addElements(Expression.Literal.newBuilder().setBoolean(sparseMatrix.isTransposed()));
            return Expression.Literal.newBuilder().setStruct(newBuilder3).build();
        }
        if (!(obj instanceof DenseMatrix)) {
            if (obj instanceof Byte ? true : obj instanceof Short ? true : obj instanceof Integer ? true : obj instanceof Long ? true : obj instanceof Float ? true : obj instanceof Double ? true : obj instanceof Boolean ? true : obj instanceof String ? true : ScalaRunTime$.MODULE$.isArray(obj, 1)) {
                return LiteralValueProtoConverter$.MODULE$.toLiteralProto(obj);
            }
            throw new MlUnsupportedException(obj + " not supported");
        }
        DenseMatrix denseMatrix = (DenseMatrix) obj;
        Expression.Literal.Struct.Builder newBuilder4 = Expression.Literal.Struct.newBuilder();
        newBuilder4.setStructType(ProtoDataTypes$.MODULE$.MatrixUDT());
        newBuilder4.addElements(Expression.Literal.newBuilder().setByte(1));
        newBuilder4.addElements(Expression.Literal.newBuilder().setInteger(denseMatrix.numRows()));
        newBuilder4.addElements(Expression.Literal.newBuilder().setInteger(denseMatrix.numCols()));
        newBuilder4.addElements(Expression.Literal.newBuilder().setNull(ProtoDataTypes$.MODULE$.NullType()));
        newBuilder4.addElements(Expression.Literal.newBuilder().setNull(ProtoDataTypes$.MODULE$.NullType()));
        newBuilder4.addElements(buildDoubleArray(denseMatrix.values()));
        newBuilder4.addElements(Expression.Literal.newBuilder().setBoolean(denseMatrix.isTransposed()));
        return Expression.Literal.newBuilder().setStruct(newBuilder4).build();
    }

    private Expression.Literal buildIntArray(int[] iArr) {
        Ints.Builder newBuilder = Ints.newBuilder();
        ArrayOps$.MODULE$.foreach$extension(Predef$.MODULE$.intArrayOps(iArr), obj -> {
            return newBuilder.addValues(BoxesRunTime.unboxToInt(obj));
        });
        return Expression.Literal.newBuilder().setSpecializedArray(Expression.Literal.SpecializedArray.newBuilder().setInts(newBuilder).build()).build();
    }

    private Expression.Literal buildDoubleArray(double[] dArr) {
        Doubles.Builder newBuilder = Doubles.newBuilder();
        ArrayOps$.MODULE$.foreach$extension(Predef$.MODULE$.doubleArrayOps(dArr), obj -> {
            return newBuilder.addValues(BoxesRunTime.unboxToDouble(obj));
        });
        return Expression.Literal.newBuilder().setSpecializedArray(Expression.Literal.SpecializedArray.newBuilder().setDoubles(newBuilder).build()).build();
    }

    public Tuple2<Object, Class<?>>[] deserializeMethodArguments(Fetch.Method.Args[] argsArr, SessionHolder sessionHolder) {
        return (Tuple2[]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps(argsArr), args -> {
            if (!args.hasParam()) {
                if (args.hasInput()) {
                    return new Tuple2(MLUtils$.MODULE$.parseRelationProto(args.getInput(), sessionHolder), Dataset.class);
                }
                throw new MlUnsupportedException(args + " not supported");
            }
            Expression.Literal param = args.getParam();
            Expression.Literal.LiteralTypeCase literalTypeCase = param.getLiteralTypeCase();
            if (Expression.Literal.LiteralTypeCase.STRUCT.equals(literalTypeCase)) {
                Expression.Literal.Struct struct = param.getStruct();
                String jvmClass = struct.getStructType().getUdt().getJvmClass();
                switch (jvmClass == null ? 0 : jvmClass.hashCode()) {
                    case -442626967:
                        if ("org.apache.spark.ml.linalg.MatrixUDT".equals(jvmClass)) {
                            return new Tuple2(MLUtils$.MODULE$.deserializeMatrix(struct), Matrix.class);
                        }
                        break;
                    case 1006386087:
                        if ("org.apache.spark.ml.linalg.VectorUDT".equals(jvmClass)) {
                            return new Tuple2(MLUtils$.MODULE$.deserializeVector(struct), Vector.class);
                        }
                        break;
                }
                throw new MlUnsupportedException("Unsupported struct " + param.getStruct());
            }
            if (Expression.Literal.LiteralTypeCase.INTEGER.equals(literalTypeCase)) {
                return new Tuple2(BoxesRunTime.boxToInteger(param.getInteger()), Integer.TYPE);
            }
            if (Expression.Literal.LiteralTypeCase.FLOAT.equals(literalTypeCase)) {
                return new Tuple2(BoxesRunTime.boxToDouble(param.getFloat()), Double.TYPE);
            }
            if (Expression.Literal.LiteralTypeCase.STRING.equals(literalTypeCase)) {
                return new Tuple2(param.getString(), String.class);
            }
            if (Expression.Literal.LiteralTypeCase.DOUBLE.equals(literalTypeCase)) {
                return new Tuple2(BoxesRunTime.boxToDouble(param.getDouble()), Double.TYPE);
            }
            if (Expression.Literal.LiteralTypeCase.BOOLEAN.equals(literalTypeCase)) {
                return new Tuple2(BoxesRunTime.boxToBoolean(param.getBoolean()), Boolean.TYPE);
            }
            if (!Expression.Literal.LiteralTypeCase.ARRAY.equals(literalTypeCase)) {
                throw new MlUnsupportedException(literalTypeCase + " not supported");
            }
            Expression.Literal.Array array = param.getArray();
            DataType.KindCase kindCase = array.getElementType().getKindCase();
            if (DataType.KindCase.DOUBLE.equals(kindCase)) {
                return new Tuple2(MODULE$.parseDoubleArray(array), double[].class);
            }
            if (DataType.KindCase.STRING.equals(kindCase)) {
                return new Tuple2(MODULE$.parseStringArray(array), String[].class);
            }
            if (!DataType.KindCase.ARRAY.equals(kindCase)) {
                throw new MlUnsupportedException("Unsupported array " + param);
            }
            if (DataType.KindCase.STRING.equals(array.getElementType().getArray().getElementType().getKindCase())) {
                return new Tuple2(MODULE$.parseStringArrayArray(array), String[][].class);
            }
            throw new MlUnsupportedException("Unsupported inner array " + array);
        }, ClassTag$.MODULE$.apply(Tuple2.class));
    }

    private double[] parseDoubleArray(Expression.Literal.Array array) {
        double[] dArr = new double[array.getElementsCount()];
        for (int i = 0; i < array.getElementsCount(); i++) {
            dArr[i] = array.getElements(i).getDouble();
        }
        return dArr;
    }

    private String[] parseStringArray(Expression.Literal.Array array) {
        String[] strArr = new String[array.getElementsCount()];
        for (int i = 0; i < array.getElementsCount(); i++) {
            strArr[i] = array.getElements(i).getString();
        }
        return strArr;
    }

    /* JADX WARN: Type inference failed for: r0v2, types: [java.lang.String[], java.lang.String[][]] */
    private String[][] parseStringArrayArray(Expression.Literal.Array array) {
        ?? r0 = new String[array.getElementsCount()];
        for (int i = 0; i < array.getElementsCount(); i++) {
            r0[i] = parseStringArray(array.getElements(i).getArray());
        }
        return r0;
    }

    public MlParams serializeParams(Params params) {
        MlParams.Builder newBuilder = MlParams.newBuilder();
        ArrayOps$.MODULE$.foreach$extension(Predef$.MODULE$.refArrayOps(params.params()), param -> {
            if (!params.isSet(param)) {
                return BoxedUnit.UNIT;
            }
            return newBuilder.putParams(param.name(), MODULE$.serializeParam(params.get(param).get()));
        });
        return newBuilder.build();
    }

    private Serializer$() {
    }
}
