/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.math.IntMath;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelTrait;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.metadata.RelColumnOrigin;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveGroupingID;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveRelNode;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.TypeConverter;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Planner rule that rewrites {@code COUNT(DISTINCT ...)} aggregates.
 *
 * <p>The rule only fires on aggregates with {@link Aggregate.Group#SIMPLE} grouping that contain at
 * least one {@code COUNT(DISTINCT ...)} call. Two rewrites are implemented:
 * <ul>
 *   <li>If there is more than one aggregate call and every call is a {@code COUNT(DISTINCT ...)},
 *       the aggregate becomes a grouping-sets aggregate (one grouping set per distinct argument
 *       list, plus a {@code GROUPING__ID} column). On top of it sits a projection of
 *       {@code CASE WHEN grouping_id = X ... THEN 1 ELSE NULL END} expressions and a plain
 *       {@code COUNT} aggregate.</li>
 *   <li>If there are no non-distinct calls and all distinct calls share one argument list, the
 *       input is first de-duplicated by an extra {@code SELECT DISTINCT} aggregate. The distinct
 *       calls are then evaluated as non-distinct calls on top of it. This rewrite is skipped when
 *       any distinct argument originates from a partition column of a Hive table.</li>
 * </ul>
 */
public final class HiveExpandDistinctAggregatesRule
extends RelOptRule {
    protected static final Logger LOG = LoggerFactory.getLogger(HiveExpandDistinctAggregatesRule.class);

    public static final HiveExpandDistinctAggregatesRule INSTANCE =
            new HiveExpandDistinctAggregatesRule(HiveAggregate.class, HiveRelFactories.HIVE_PROJECT_FACTORY);

    /**
     * Factory for the projection placed under the SELECT DISTINCT aggregate. This is per instance;
     * it used to be a static field overwritten by every constructor call.
     */
    private final RelFactories.ProjectFactory projFactory;

    /*
     * Retained (and still assigned in onMatch) only for compatibility with code in the same package.
     * The rewrite itself never reads these fields: the shared INSTANCE may match concurrently, so the
     * helpers derive the cluster and RexBuilder from the aggregate they are given.
     */
    RelOptCluster cluster = null;
    RexBuilder rexBuilder = null;

    /**
     * Creates the rule.
     *
     * @param clazz          the aggregate class this rule matches
     * @param projectFactory factory used to build the projection under the SELECT DISTINCT aggregate
     */
    public HiveExpandDistinctAggregatesRule(Class<? extends Aggregate> clazz, RelFactories.ProjectFactory projectFactory) {
        super(operand(clazz, any()));
        this.projFactory = projectFactory;
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        final Aggregate aggregate = (Aggregate) call.rel(0);
        final int numCountDistinct = getNumCountDistinctCall(aggregate);
        if (numCountDistinct == 0 || aggregate.getGroupType() != Aggregate.Group.SIMPLE) {
            return;
        }

        // argListList keeps one entry per distinct call (duplicates included); argListSets keeps the
        // distinct argument lists in first-seen order so the rewrite is deterministic.
        int nonDistinctCount = 0;
        final List<List<Integer>> argListList = new ArrayList<>();
        final Set<List<Integer>> argListSets = new LinkedHashSet<>();
        final ImmutableBitSet.Builder newGroupSetBuilder = ImmutableBitSet.builder();
        newGroupSetBuilder.addAll(aggregate.getGroupSet());
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            if (!aggCall.isDistinct()) {
                ++nonDistinctCount;
                continue;
            }
            final List<Integer> argList = new ArrayList<>(aggCall.getArgList());
            for (int arg : argList) {
                newGroupSetBuilder.set(arg);
            }
            argListList.add(argList);
            argListSets.add(argList);
        }
        Preconditions.checkArgument(!argListSets.isEmpty(), "containsDistinctCall lied");
        final ImmutableBitSet newGroupSet = newGroupSetBuilder.build();

        // The grouping id is a BIGINT bit mask with one bit per column of newGroupSet, so the
        // grouping-sets rewrite can only encode up to Long.SIZE columns.
        if (numCountDistinct > 1 && numCountDistinct == aggregate.getAggCallList().size()
                && newGroupSet.cardinality() <= Long.SIZE) {
            LOG.debug("Trigger countDistinct rewrite. numCountDistinct is {}", numCountDistinct);
            this.cluster = aggregate.getCluster();
            this.rexBuilder = this.cluster.getRexBuilder();
            try {
                call.transformTo(convert(aggregate, argListList, newGroupSet));
            } catch (CalciteSemanticException e) {
                LOG.debug(e.toString());
                throw new RuntimeException(e);
            }
            return;
        }

        if (nonDistinctCount == 0 && argListSets.size() == 1) {
            final List<Integer> distinctArgs = argListSets.iterator().next();
            // NOTE(review): presumably left to another optimization (e.g. metadata-only) when a
            // partition column is involved — confirm.
            if (hasPartitionColumnOrigin(call.getMetadataQuery(), aggregate.getInput(), distinctArgs)) {
                return;
            }
            call.transformTo(convertMonopole(aggregate, distinctArgs));
        }
    }

    /** Returns true if any of {@code args} (ordinals of {@code input}) originates from a Hive partition column. */
    private static boolean hasPartitionColumnOrigin(RelMetadataQuery mq, RelNode input, List<Integer> args) {
        for (int arg : args) {
            final Set<RelColumnOrigin> origins = mq.getColumnOrigins(input, arg);
            if (origins == null) {
                continue;
            }
            for (RelColumnOrigin origin : origins) {
                if (origin.getOriginTable() instanceof RelOptHiveTable
                        && ((RelOptHiveTable) origin.getOriginTable()).getPartColInfoMap()
                                .containsKey(origin.getOriginColumnOrdinal())) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Rewrites an aggregate made only of COUNT(DISTINCT) calls into grouping sets plus a COUNT aggregate.
     *
     * @param aggregate   the original aggregate
     * @param argList     argument list of every distinct call, in call order
     * @param newGroupSet original group keys plus every distinct argument column
     */
    private RelNode convert(Aggregate aggregate, List<List<Integer>> argList, ImmutableBitSet newGroupSet)
            throws CalciteSemanticException {
        // map: call index -> index in cleanArgList of an earlier call with the same grouping set
        final Map<Integer, Integer> map = new HashMap<>();
        final List<List<Integer>> cleanArgList = new ArrayList<>();
        final Aggregate groupingSets = createGroupingSets(aggregate, argList, cleanArgList, map, newGroupSet);
        return createCount(groupingSets, argList, cleanArgList, map, aggregate.getGroupSet(), newGroupSet);
    }

    /**
     * Computes the GROUPING__ID value identifying the grouping set {@code originalGroupSet ∪ list}.
     *
     * <p>The mask has one bit per column of {@code newGroupSet}, with the first column as the most
     * significant bit. All bits start set, and the bit of every column that belongs to the grouping
     * set is cleared. The value is computed in {@code long} arithmetic to match the BIGINT
     * GROUPING__ID column; callers must ensure {@code groupCount <= Long.SIZE}.
     */
    private long getGroupingIdValue(List<Integer> list, ImmutableBitSet originalGroupSet,
            ImmutableBitSet newGroupSet, int groupCount) {
        long ind = groupCount >= Long.SIZE ? -1L : (1L << groupCount) - 1;
        for (int pos : originalGroupSet) {
            ind &= ~(1L << (groupCount - newGroupSet.indexOf(pos) - 1));
        }
        for (int pos : list) {
            ind &= ~(1L << (groupCount - newGroupSet.indexOf(pos) - 1));
        }
        return ind;
    }

    /** Returns one input reference per output field of {@code rel}, in field order. */
    private static List<RexInputRef> refsTo(RelNode rel) {
        return rel.getRowType().getFieldList().stream()
                .map(field -> new RexInputRef(field.getIndex(), field.getType()))
                .collect(Collectors.toList());
    }

    /**
     * Builds the COUNT aggregate on top of the grouping-sets aggregate.
     *
     * <p>For every distinct grouping set, a CASE projection yields 1 on rows whose GROUPING__ID
     * matches that set (and, for single-argument lists, whose argument is not null), otherwise
     * NULL. Those columns are then COUNTed, grouped by the original group keys. If some calls shared
     * a grouping set, a final projection restores one output column per original call.
     *
     * @param aggr             the grouping-sets aggregate; its output is the columns of
     *                         {@code newGroupSet} in ascending order, followed by GROUPING__ID
     * @param argList          argument list of every original distinct call
     * @param cleanArgList     argument lists with duplicated grouping sets removed
     * @param map              original call index to the cleanArgList index it duplicates
     * @param originalGroupSet group keys of the original aggregate (input ordinals)
     * @param newGroupSet      group set of {@code aggr} (input ordinals)
     */
    private RelNode createCount(Aggregate aggr, List<List<Integer>> argList, List<List<Integer>> cleanArgList,
            Map<Integer, Integer> map, ImmutableBitSet originalGroupSet, ImmutableBitSet newGroupSet)
            throws CalciteSemanticException {
        final RelOptCluster cl = aggr.getCluster();
        final RexBuilder rb = cl.getRexBuilder();
        final List<RexInputRef> inputRefs = refsTo(aggr);
        final RexNode groupingId = inputRefs.get(inputRefs.size() - 1);

        final List<RexNode> gbChildProjLst = new ArrayList<>();
        for (List<Integer> list : cleanArgList) {
            final long groupingIdValue = getGroupingIdValue(list, originalGroupSet, newGroupSet, aggr.getGroupCount());
            RexNode condition = rb.makeCall(SqlStdOperatorTable.EQUALS, groupingId,
                    rb.makeExactLiteral(BigDecimal.valueOf(groupingIdValue)));
            // A single-argument count must not count NULLs; multi-argument tuples are counted as-is.
            if (list.size() == 1) {
                // Input ordinal -> position in the grouping-sets output (not the raw input ordinal).
                final RexNode arg = inputRefs.get(newGroupSet.indexOf(list.get(0)));
                condition = rb.makeCall(SqlStdOperatorTable.AND, condition,
                        rb.makeCall(SqlStdOperatorTable.IS_NOT_NULL, arg));
            }
            final RexNode one = rb.makeExactLiteral(BigDecimal.ONE);
            final RexNode nullOfSameType = rb.makeNullLiteral(one.getType());
            gbChildProjLst.add(rb.makeCall(SqlStdOperatorTable.CASE, condition, one, nullOfSameType));
        }
        for (int pos : originalGroupSet) {
            gbChildProjLst.add(inputRefs.get(newGroupSet.indexOf(pos)));
        }
        final HiveProject gbInputRel = HiveProject.create(aggr, gbChildProjLst, null);

        final RelDataType countType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, cl.getTypeFactory());
        final int numCounts = cleanArgList.size();
        final List<AggregateCall> aggregateCalls = new ArrayList<>(numCounts);
        for (int i = 0; i < numCounts; ++i) {
            aggregateCalls.add(HiveCalciteUtil.createSingleArgAggCall(
                    "count", cl, TypeInfoFactory.longTypeInfo, i, countType));
        }
        // Group keys sit after the CASE columns in gbInputRel.
        final ImmutableBitSet groupSet = ImmutableBitSet.range(numCounts, numCounts + originalGroupSet.cardinality());
        final HiveAggregate aggregate = new HiveAggregate(cl, cl.traitSetOf(HiveRelNode.CONVENTION),
                gbInputRel, groupSet, null, aggregateCalls);
        if (map.isEmpty()) {
            return aggregate;
        }

        // Aggregate output: group keys first, then one count per cleanArgList entry.
        final List<RexInputRef> aggrRefs = refsTo(aggregate);
        final int numGroupKeys = groupSet.cardinality();
        final List<RexNode> projLst = new ArrayList<>(numGroupKeys + argList.size());
        projLst.addAll(aggrRefs.subList(0, numGroupKeys));
        int nextCount = numGroupKeys;
        for (int i = 0; i < argList.size(); ++i) {
            final Integer duplicateOf = map.get(i);
            if (duplicateOf != null) {
                // Offset by the group keys: map values index cleanArgList, not the aggregate output.
                projLst.add(aggrRefs.get(numGroupKeys + duplicateOf));
            } else {
                projLst.add(aggrRefs.get(nextCount++));
            }
        }
        return HiveProject.create(aggregate, projLst, null);
    }

    /**
     * Builds the grouping-sets aggregate: one grouping set per distinct
     * {@code aggregate.getGroupSet() ∪ args}, plus a GROUPING__ID (BIGINT) call.
     *
     * <p>Fills {@code cleanArgList} with the first argument list producing each grouping set. It
     * also fills {@code map} with "call index -> cleanArgList index" for later calls whose grouping
     * set was already seen.
     */
    private Aggregate createGroupingSets(Aggregate aggregate, List<List<Integer>> argList,
            List<List<Integer>> cleanArgList, Map<Integer, Integer> map, ImmutableBitSet groupSet) {
        final RelOptCluster cl = aggregate.getCluster();
        final List<ImmutableBitSet> origGroupSets = new ArrayList<>();
        for (int i = 0; i < argList.size(); ++i) {
            final List<Integer> list = argList.get(i);
            final ImmutableBitSet bitSet = aggregate.getGroupSet().union(ImmutableBitSet.of(list));
            final int prev = origGroupSets.indexOf(bitSet);
            if (prev == -1) {
                origGroupSets.add(bitSet);
                cleanArgList.add(list);
            } else {
                map.put(i, prev);
            }
        }
        origGroupSets.sort(ImmutableBitSet.COMPARATOR);
        final AggregateCall groupingIdCall = AggregateCall.create((SqlAggFunction) HiveGroupingID.INSTANCE, false,
                ImmutableList.<Integer>of(), -1, cl.getTypeFactory().createSqlType(SqlTypeName.BIGINT),
                HiveGroupingID.INSTANCE.getName());
        final List<AggregateCall> aggregateCalls = new ArrayList<>();
        aggregateCalls.add(groupingIdCall);
        return new HiveAggregate(cl, cl.traitSetOf(HiveRelNode.CONVENTION), aggregate.getInput(),
                groupSet, origGroupSets, aggregateCalls);
    }

    /** Counts the calls that are DISTINCT and whose aggregation is named "count" (case-insensitive). */
    private int getNumCountDistinctCall(Aggregate hiveAggregate) {
        int cnt = 0;
        for (AggregateCall aggCall : hiveAggregate.getAggCallList()) {
            if (aggCall.isDistinct() && aggCall.getAggregation().getName().equalsIgnoreCase("count")) {
                ++cnt;
            }
        }
        return cnt;
    }

    /**
     * Rewrites an aggregate whose distinct calls all share {@code argList} (and that has no
     * non-distinct calls). It becomes an aggregate over a SELECT DISTINCT of the group keys and
     * those arguments.
     */
    private RelNode convertMonopole(Aggregate aggregate, List<Integer> argList) {
        final Map<Integer, Integer> sourceOf = new HashMap<>();
        final Aggregate distinct = createSelectDistinct(aggregate, argList, sourceOf);
        final List<AggregateCall> newAggCalls = new ArrayList<>(aggregate.getAggCallList());
        rewriteAggCalls(newAggCalls, argList, sourceOf);
        final int cardinality = aggregate.getGroupSet().cardinality();
        return aggregate.copy(aggregate.getTraitSet(), distinct, false, ImmutableBitSet.range(cardinality), null,
                newAggCalls);
    }

    /**
     * Replaces, in place, every DISTINCT call over exactly {@code argList} with a non-distinct call
     * whose arguments are remapped through {@code sourceOf}. The replacement call carries no filter
     * argument.
     */
    private static void rewriteAggCalls(List<AggregateCall> newAggCalls, List<Integer> argList,
            Map<Integer, Integer> sourceOf) {
        for (int i = 0; i < newAggCalls.size(); ++i) {
            final AggregateCall aggCall = newAggCalls.get(i);
            if (!aggCall.isDistinct() || !aggCall.getArgList().equals(argList)) {
                continue;
            }
            final List<Integer> newArgs = new ArrayList<>(aggCall.getArgList().size());
            for (Integer arg : aggCall.getArgList()) {
                newArgs.add(sourceOf.get(arg));
            }
            newAggCalls.set(i, AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1,
                    aggCall.getType(), aggCall.getName()));
        }
    }

    /**
     * Builds {@code SELECT DISTINCT groupKeys, args} over the aggregate's input. It records in
     * {@code sourceOf} the output position of every projected input ordinal; group keys come first,
     * then any argument not already projected.
     */
    private Aggregate createSelectDistinct(Aggregate aggregate, List<Integer> argList, Map<Integer, Integer> sourceOf) {
        final RelNode child = aggregate.getInput();
        final List<Pair<RexNode, String>> projects = new ArrayList<>();
        for (int i : aggregate.getGroupSet()) {
            sourceOf.put(i, projects.size());
            projects.add(RexInputRef.of2(i, child.getRowType().getFieldList()));
        }
        for (Integer arg : argList) {
            if (sourceOf.get(arg) != null) {
                continue;
            }
            sourceOf.put(arg, projects.size());
            projects.add(RexInputRef.of2(arg, child.getRowType().getFieldList()));
        }
        final RelNode project = projFactory.createProject(child, Pair.left(projects), Pair.right(projects));
        return aggregate.copy(aggregate.getTraitSet(), project, false, ImmutableBitSet.range(projects.size()), null,
                ImmutableList.<AggregateCall>of());
    }
}

