package org.apache.spark.ml.feature;

import java.util.Arrays;

import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.Attribute$;
import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.BinaryAttribute;
import org.apache.spark.ml.attribute.BinaryAttribute$;
import org.apache.spark.ml.attribute.NominalAttribute;
import org.apache.spark.ml.attribute.NumericAttribute;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StructField;
import scala.Array$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.collection.ArrayOps$;
import scala.collection.IterableOps;
import scala.collection.immutable.$colon;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.Seq;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;

/* compiled from: OneHotEncoder.scala */
/* loaded from: input_file:org/apache/spark/ml/feature/OneHotEncoderCommon$.class */
/**
 * Shared helpers that build the ML attribute metadata ({@link AttributeGroup} of
 * {@link BinaryAttribute}s) describing a one-hot encoder's output vector column.
 *
 * <p>Attribute names are derived either from the input column's existing attribute metadata
 * ({@link #transformOutputColumnSchema}) or from the largest index found in the data
 * ({@link #getOutputAttrGroupFromData}). This is the Java form of a Scala {@code object};
 * {@link #MODULE$} is its singleton instance.
 */
public final class OneHotEncoderCommon$ {
    public static final OneHotEncoderCommon$ MODULE$ = new OneHotEncoderCommon$();

    /** Name of the extra attribute appended when invalid values get their own output slot. */
    private static final String INVALID_VALUES_ATTR = "invalidValues";

    /**
     * Derives output attribute names from the ML attribute metadata attached to a column.
     *
     * @param structField input column schema; its metadata is parsed with
     *     {@code Attribute.fromStructField}
     * @return explicit category values when present; otherwise {@code "0".."n-1"} for a nominal
     *     attribute with known cardinality {@code n}, or {@code "0","1"} for a binary attribute;
     *     {@code None} when a nominal attribute has neither, or for any other attribute kind
     * @throws RuntimeException if the column carries numeric (continuous) attribute metadata
     */
    private Option<String[]> genOutputAttrNames(StructField structField) {
        Attribute inputAttr = Attribute$.MODULE$.fromStructField(structField);
        if (inputAttr instanceof NominalAttribute) {
            NominalAttribute nominal = (NominalAttribute) inputAttr;
            if (nominal.values().isDefined()) {
                return nominal.values();
            }
            if (nominal.numValues().isDefined()) {
                return nominal.numValues().map(n -> indexNames(BoxesRunTime.unboxToInt(n)));
            }
            return None$.MODULE$;
        }
        if (inputAttr instanceof BinaryAttribute) {
            BinaryAttribute binary = (BinaryAttribute) inputAttr;
            if (binary.values().isDefined()) {
                return binary.values();
            }
            return new Some<>(indexNames(2));
        }
        if (inputAttr instanceof NumericAttribute) {
            throw new RuntimeException("The input column " + structField.name() + " cannot be continuous-value.");
        }
        // Any other attribute kind: no names can be derived from the metadata.
        return None$.MODULE$;
    }

    /**
     * Builds the output {@link AttributeGroup}.
     *
     * @param option attribute names, one default {@link BinaryAttribute} per name;
     *     {@code None} yields a group without per-attribute metadata
     * @param str output column name
     * @return the attribute group named {@code str}
     */
    private AttributeGroup genOutputAttrGroup(Option<String[]> option, String str) {
        if (option.isEmpty()) {
            return new AttributeGroup(str);
        }
        String[] names = option.get();
        BinaryAttribute template = BinaryAttribute$.MODULE$.defaultAttr();
        Attribute[] attrs = new Attribute[names.length];
        for (int idx = 0; idx < names.length; idx++) {
            attrs[idx] = template.withName(names[idx]);
        }
        return new AttributeGroup(str, attrs);
    }

    /**
     * Prepares the output column's {@link StructField}, carrying attribute metadata derived from
     * the input column's metadata.
     *
     * @param structField input column schema
     * @param str output column name
     * @param z drop the last category (applied only when {@code z2} is false)
     * @param z2 append an {@code "invalidValues"} category (applied only when {@code z} is false)
     * @return output column schema; without per-attribute metadata when no names are known
     * @throws IllegalArgumentException if the last category would be dropped from a column with
     *     fewer than two known categories
     * @throws RuntimeException if the input column is continuous-valued
     */
    public StructField transformOutputColumnSchema(StructField structField, String str, boolean z, boolean z2) {
        boolean dropLast = z;
        boolean keepInvalid = z2;
        Option<String[]> outputNames = genOutputAttrNames(structField).map(names -> {
            if (dropLast && !keepInvalid) {
                Predef$.MODULE$.require(names.length > 1, () ->
                        "The input column " + structField.name() + " should have at least two distinct values.");
            }
            return adjustAttrNames(names, dropLast, keepInvalid);
        });
        return genOutputAttrGroup(outputNames, str).toStructField();
    }

    /** Default value of the fourth parameter of {@link #transformOutputColumnSchema}. */
    public boolean transformOutputColumnSchema$default$4() {
        return false;
    }

    /**
     * Builds one output {@link AttributeGroup} per column from the data itself.
     *
     * <p>Each input column is cast to double. Its maximum value {@code m} is found with
     * {@code treeAggregate}, and {@code m + 1} binary attributes named {@code "0".."m"} are
     * created. The job fails (via {@code Predef.assert} inside the aggregation) if any value is
     * negative, non-integral, NaN, or not strictly below {@code Integer.MAX_VALUE}.
     *
     * @param dataset data to scan
     * @param seq input column names
     * @param seq2 output column names, paired positionally with {@code seq}
     * @param z drop-last flag forwarded to {@link #createAttrGroupForAttrNames} (with the
     *     invalid-slot flag set to false)
     * @return one attribute group per paired output column
     */
    public Seq<AttributeGroup> getOutputAttrGroupFromData(Dataset<?> dataset, Seq<String> seq, Seq<String> seq2, boolean z) {
        boolean dropLast = z;
        // Cast every input column to double so index columns of any numeric type are checked uniformly.
        Seq columns = (Seq) seq.map(name -> functions$.MODULE$.col(name).cast(DoubleType$.MODULE$));
        int numColumns = columns.length();

        // NOTE(review): Spark ships these closures to executors. A Java lambda is only Serializable
        // when its target type is, and scala.FunctionN is not; confirm if this source is ever compiled.
        RDD<double[]> rowValues = dataset.select(columns).rdd().map(row -> {
            double[] values = new double[numColumns];
            for (int idx = 0; idx < numColumns; idx++) {
                values[idx] = row.getDouble(idx);
            }
            return values;
        }, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Double.TYPE)));

        double[] zero = new double[numColumns];
        double[] maxValues = rowValues.treeAggregate(zero, (acc, values) -> {
            for (int idx = 0; idx < numColumns; idx++) {
                final int column = idx;
                final double d = values[idx];
                // Strict '<': the attribute count d + 1 computed below must still fit in an int.
                Predef$.MODULE$.assert(d < (double) Integer.MAX_VALUE, () ->
                        "OneHotEncoder only supports up to " + Integer.MAX_VALUE + " indices, but got " + d + ".");
                Predef$.MODULE$.assert(d >= 0.0d && d == ((double) ((int) d)), () ->
                        "Values from column " + seq.apply(column) + " must be indices, but got " + d + ".");
                acc[idx] = Math.max(acc[idx], d);
            }
            return acc;
        }, (left, right) -> {
            for (int idx = 0; idx < numColumns; idx++) {
                left[idx] = Math.max(left[idx], right[idx]);
            }
            return left;
        }, rowValues.treeAggregate$default$4(zero), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Double.TYPE)));

        // Category count = largest index + 1.
        int[] numAttrs = new int[numColumns];
        for (int idx = 0; idx < numColumns; idx++) {
            numAttrs[idx] = ((int) maxValues[idx]) + 1;
        }

        return (Seq) ((IterableOps) seq2.zip(Predef$.MODULE$.wrapIntArray(numAttrs))).map(tuple2 -> {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            return MODULE$.createAttrGroupForAttrNames((String) tuple2._1(), tuple2._2$mcI$sp(), dropLast, false);
        });
    }

    /**
     * Creates an {@link AttributeGroup} of binary attributes named {@code "0".."i-1"}, adjusted by
     * the drop-last / keep-invalid options.
     *
     * <p>Unlike {@link #transformOutputColumnSchema}, this method performs no minimum-category
     * check: {@code i == 1} with dropping enabled yields an empty group.
     *
     * @param str output column name
     * @param i number of categories; non-positive values yield no names
     * @param z drop the last name (applied only when {@code z2} is false)
     * @param z2 append an {@code "invalidValues"} name (applied only when {@code z} is false)
     * @return the attribute group named {@code str}
     */
    public AttributeGroup createAttrGroupForAttrNames(String str, int i, boolean z, boolean z2) {
        return genOutputAttrGroup(new Some<>(adjustAttrNames(indexNames(i), z, z2)), str);
    }

    /**
     * Applies the output-size options to a list of category names.
     *
     * @param names category names; returned as-is when no adjustment applies
     * @param dropLast drop the last name
     * @param keepInvalid append {@code "invalidValues"}
     * @return {@code names} without its last element when {@code dropLast && !keepInvalid} (an
     *     empty array stays empty); a copy of {@code names} plus {@code "invalidValues"} when
     *     {@code !dropLast && keepInvalid}; otherwise {@code names} itself
     */
    private static String[] adjustAttrNames(String[] names, boolean dropLast, boolean keepInvalid) {
        if (dropLast && !keepInvalid) {
            return Arrays.copyOf(names, Math.max(names.length - 1, 0));
        }
        if (!dropLast && keepInvalid) {
            String[] extended = Arrays.copyOf(names, names.length + 1);
            extended[names.length] = INVALID_VALUES_ATTR;
            return extended;
        }
        // Both options set, or neither: the names are used unchanged.
        return names;
    }

    /**
     * Returns {@code {"0", "1", ..., "n-1"}}.
     *
     * @param n number of names; non-positive values give an empty array, as Scala's
     *     {@code Array.tabulate} does
     */
    private static String[] indexNames(int n) {
        String[] names = new String[Math.max(n, 0)];
        for (int k = 0; k < names.length; k++) {
            names[k] = Integer.toString(k);
        }
        return names;
    }

    /** Compiler-generated lambda body, kept for binary compatibility; returns {@code "0".."i-1"}. */
    public static final /* synthetic */ String[] $anonfun$genOutputAttrNames$1(int i) {
        return indexNames(i);
    }

    private OneHotEncoderCommon$() {
    }
}
