/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.AggregateJoinTransposeRule;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlSplittableAggFunction;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class HiveAggregateJoinTransposeRule
extends AggregateJoinTransposeRule {
    private static final Logger LOG = LoggerFactory.getLogger(HiveAggregateJoinTransposeRule.class);
    private final boolean allowFunctions;
    private final AtomicInteger noColsMissingStats;
    private boolean costBased;
    private boolean uniqueBased;

    public HiveAggregateJoinTransposeRule(AtomicInteger noColsMissingStats, boolean costBased, boolean uniqueBased) {
        super(HiveAggregate.class, HiveJoin.class, HiveRelFactories.HIVE_BUILDER, true);
        this.costBased = costBased;
        this.uniqueBased = uniqueBased;
        this.allowFunctions = true;
        this.noColsMissingStats = noColsMissingStats;
    }

    public void onMatch(RelOptRuleCall call) {
        try {
            Object aggregation;
            Aggregate aggregate = (Aggregate)call.rel(0);
            Join join = (Join)call.rel(1);
            RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
            RelBuilder relBuilder = call.builder();
            for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
                if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
                    return;
                }
                if (aggregateCall.filterArg < 0) continue;
                return;
            }
            if (join.getJoinType() != JoinRelType.INNER) {
                return;
            }
            if (!this.allowFunctions && !aggregate.getAggCallList().isEmpty()) {
                return;
            }
            boolean groupingUnique = this.isGroupingUnique((RelNode)join, aggregate.getGroupSet());
            if (!groupingUnique && !this.costBased) {
                return;
            }
            ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
            RelMetadataQuery mq = call.getMetadataQuery();
            ImmutableBitSet keyColumns = HiveAggregateJoinTransposeRule.keyColumns(aggregateColumns, (ImmutableList<RexNode>)mq.getPulledUpPredicates((RelNode)join).pulledUpPredicates);
            ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits((RexNode)join.getCondition());
            boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
            ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);
            ArrayList leftKeys = Lists.newArrayList();
            ArrayList rightKeys = Lists.newArrayList();
            ArrayList filterNulls = Lists.newArrayList();
            RexNode nonEquiConj = RelOptUtil.splitJoinCondition((RelNode)join.getLeft(), (RelNode)join.getRight(), (RexNode)join.getCondition(), (List)leftKeys, (List)rightKeys, (List)filterNulls);
            if (!nonEquiConj.isAlwaysTrue()) {
                return;
            }
            HashMap<Object, Integer> map = new HashMap<Object, Integer>();
            ArrayList<Side> sides = new ArrayList<Side>();
            int uniqueCount = 0;
            int offset = 0;
            int belowOffset = 0;
            for (int s = 0; s < 2; ++s) {
                boolean unique;
                Side side = new Side();
                RelNode joinInput = join.getInput(s);
                int fieldCount = joinInput.getRowType().getFieldCount();
                ImmutableBitSet fieldSet = ImmutableBitSet.range((int)offset, (int)(offset + fieldCount));
                ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
                for (Ord c : Ord.zip((Iterable)belowAggregateKeyNotShifted)) {
                    map.put(c.e, belowOffset + c.i);
                }
                ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
                if (!this.allowFunctions) {
                    assert (aggregate.getAggCallList().isEmpty());
                    unique = true;
                } else {
                    Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey, true);
                    boolean bl = unique = unique0 != null && unique0 != false;
                }
                if (unique) {
                    ++uniqueCount;
                    relBuilder.push(joinInput);
                    relBuilder.project((Iterable)belowAggregateKey.asList().stream().map(arg_0 -> ((RelBuilder)relBuilder).field(arg_0)).collect(Collectors.toList()));
                    side.newInput = relBuilder.build();
                } else {
                    ArrayList belowAggCalls = new ArrayList();
                    SqlSplittableAggFunction.Registry belowAggCallRegistry = HiveAggregateJoinTransposeRule.registry(belowAggCalls);
                    Mappings.IdentityMapping mapping = s == 0 ? Mappings.createIdentity((int)fieldCount) : Mappings.createShiftMapping((int)(fieldCount + offset), (int[])new int[]{0, offset, fieldCount});
                    for (Ord aggCall : Ord.zip((List)aggregate.getAggCallList())) {
                        SqlAggFunction aggregation2 = ((AggregateCall)aggCall.e).getAggregation();
                        SqlSplittableAggFunction splitter = (SqlSplittableAggFunction)Preconditions.checkNotNull((Object)aggregation2.unwrap(SqlSplittableAggFunction.class));
                        AggregateCall call1 = fieldSet.contains(ImmutableBitSet.of((Iterable)((AggregateCall)aggCall.e).getArgList())) ? splitter.split((AggregateCall)aggCall.e, (Mappings.TargetMapping)mapping) : splitter.other(rexBuilder.getTypeFactory(), (AggregateCall)aggCall.e);
                        if (call1 == null) continue;
                        side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register((Object)call1));
                    }
                    side.newInput = relBuilder.push(joinInput).aggregate(relBuilder.groupKey(belowAggregateKey, null), belowAggCalls).build();
                }
                offset += fieldCount;
                belowOffset += side.newInput.getRowType().getFieldCount();
                sides.add(side);
            }
            if (uniqueCount == 2) {
                return;
            }
            Mapping mapping = (Mapping)Mappings.target(map::get, (int)join.getRowType().getFieldCount(), (int)belowOffset);
            RexNode newCondition = RexUtil.apply((Mappings.TargetMapping)mapping, (RexNode)join.getCondition());
            relBuilder.push(((Side)sides.get((int)0)).newInput).push(((Side)sides.get((int)1)).newInput).join(join.getJoinType(), newCondition);
            ArrayList<AggregateCall> newAggCalls = new ArrayList<AggregateCall>();
            int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
            int newLeftWidth = ((Side)sides.get((int)0)).newInput.getRowType().getFieldCount();
            ArrayList projects = new ArrayList(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
            for (Ord aggCall : Ord.zip((List)aggregate.getAggCallList())) {
                aggregation = ((AggregateCall)aggCall.e).getAggregation();
                SqlSplittableAggFunction splitter = (SqlSplittableAggFunction)Preconditions.checkNotNull((Object)aggregation.unwrap(SqlSplittableAggFunction.class));
                Integer leftSubTotal = ((Side)sides.get((int)0)).split.get(aggCall.i);
                Integer rightSubTotal = ((Side)sides.get((int)1)).split.get(aggCall.i);
                newAggCalls.add(splitter.topSplit(rexBuilder, HiveAggregateJoinTransposeRule.registry(projects), groupIndicatorCount, relBuilder.peek().getRowType(), (AggregateCall)aggCall.e, leftSubTotal == null ? -1 : leftSubTotal, rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
            }
            relBuilder.project(projects);
            boolean aggConvertedToProjects = false;
            if (allColumnsInAggregate) {
                ArrayList<Object> projects2 = new ArrayList<Object>();
                aggregation = Mappings.apply((Mapping)mapping, (ImmutableBitSet)aggregate.getGroupSet()).iterator();
                while (aggregation.hasNext()) {
                    int key = (Integer)aggregation.next();
                    projects2.add(relBuilder.field(key));
                }
                for (AggregateCall newAggCall : newAggCalls) {
                    SqlSplittableAggFunction splitter = (SqlSplittableAggFunction)newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
                    if (splitter == null) continue;
                    RelDataType rowType = relBuilder.peek().getRowType();
                    projects2.add(splitter.singleton(rexBuilder, rowType, newAggCall));
                }
                if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
                    relBuilder.project(projects2);
                    aggConvertedToProjects = true;
                }
            }
            if (!aggConvertedToProjects) {
                relBuilder.aggregate(relBuilder.groupKey(Mappings.apply((Mapping)mapping, (ImmutableBitSet)aggregate.getGroupSet()), Mappings.apply2((Mapping)mapping, (Iterable)aggregate.getGroupSets())), newAggCalls);
            }
            RelNode r = relBuilder.build();
            boolean transform = false;
            if (this.uniqueBased && aggConvertedToProjects) {
                transform = groupingUnique;
            }
            if (!transform && this.costBased) {
                RelOptCost afterCost = mq.getCumulativeCost(r);
                RelOptCost beforeCost = mq.getCumulativeCost((RelNode)aggregate);
                transform = afterCost.isLt(beforeCost);
            }
            if (transform) {
                call.transformTo(r);
            }
        }
        catch (Exception e) {
            if (this.noColsMissingStats.get() > 0) {
                LOG.warn("Missing column stats (see previous messages), skipping aggregate-join transpose in CBO");
                this.noColsMissingStats.set(0);
            }
            throw e;
        }
    }

    private boolean isGroupingUnique(RelNode input, ImmutableBitSet groups) {
        Join join;
        JoinInfo ji;
        if (groups.isEmpty()) {
            return false;
        }
        if (input instanceof HepRelVertex) {
            HepRelVertex vertex = (HepRelVertex)input;
            return this.isGroupingUnique(vertex.getCurrentRel(), groups);
        }
        RelMetadataQuery mq = input.getCluster().getMetadataQuery();
        Set uKeys = mq.getUniqueKeys(input);
        if (uKeys == null) {
            return false;
        }
        for (ImmutableBitSet u : uKeys) {
            if (!groups.contains(u)) continue;
            return true;
        }
        if (input instanceof Join && (ji = JoinInfo.of((RelNode)(join = (Join)input).getLeft(), (RelNode)join.getRight(), (RexNode)join.getCondition())).isEqui()) {
            ImmutableBitSet newGroup = groups.intersect(RelOptUtil.InputFinder.bits((RexNode)join.getCondition()));
            RelNode l = join.getLeft();
            RelNode r = join.getRight();
            int joinFieldCount = join.getRowType().getFieldCount();
            int lFieldCount = l.getRowType().getFieldCount();
            ImmutableBitSet groupL = newGroup.get(0, lFieldCount);
            ImmutableBitSet groupR = newGroup.get(lFieldCount, joinFieldCount).shift(-lFieldCount);
            if (this.isGroupingUnique(l, groupL)) {
                return true;
            }
            if (this.isGroupingUnique(r, groupR)) {
                return true;
            }
        }
        if (input instanceof Project) {
            Project project = (Project)input;
            ImmutableBitSet.Builder newGroup = ImmutableBitSet.builder();
            Iterator iterator = groups.asList().iterator();
            while (iterator.hasNext()) {
                int g = (Integer)iterator.next();
                RexNode rex = (RexNode)project.getChildExps().get(g);
                if (!(rex instanceof RexInputRef)) continue;
                RexInputRef rexInputRef = (RexInputRef)rex;
                newGroup.set(rexInputRef.getIndex());
            }
            return this.isGroupingUnique(project.getInput(), newGroup.build());
        }
        return false;
    }

    private static ImmutableBitSet keyColumns(ImmutableBitSet aggregateColumns, ImmutableList<RexNode> predicates) {
        TreeMap<Integer, BitSet> equivalence = new TreeMap<Integer, BitSet>();
        for (RexNode predicate : predicates) {
            HiveAggregateJoinTransposeRule.populateEquivalences(equivalence, predicate);
        }
        ImmutableBitSet keyColumns = aggregateColumns;
        for (Integer aggregateColumn : aggregateColumns) {
            BitSet bitSet = (BitSet)equivalence.get(aggregateColumn);
            if (bitSet == null) continue;
            keyColumns = keyColumns.union(bitSet);
        }
        return keyColumns;
    }

    private static void populateEquivalences(Map<Integer, BitSet> equivalence, RexNode predicate) {
        switch (predicate.getKind()) {
            case EQUALS: {
                RexCall call = (RexCall)predicate;
                List operands = call.getOperands();
                if (!(operands.get(0) instanceof RexInputRef)) break;
                RexInputRef ref0 = (RexInputRef)operands.get(0);
                if (!(operands.get(1) instanceof RexInputRef)) break;
                RexInputRef ref1 = (RexInputRef)operands.get(1);
                HiveAggregateJoinTransposeRule.populateEquivalence(equivalence, ref0.getIndex(), ref1.getIndex());
                HiveAggregateJoinTransposeRule.populateEquivalence(equivalence, ref1.getIndex(), ref0.getIndex());
            }
        }
    }

    private static void populateEquivalence(Map<Integer, BitSet> equivalence, int i0, int i1) {
        BitSet bitSet = equivalence.get(i0);
        if (bitSet == null) {
            bitSet = new BitSet();
            equivalence.put(i0, bitSet);
        }
        bitSet.set(i1);
    }

    private static <E> SqlSplittableAggFunction.Registry<E> registry(final List<E> list) {
        return new SqlSplittableAggFunction.Registry<E>(){

            public int register(E e) {
                int i = list.indexOf(e);
                if (i < 0) {
                    i = list.size();
                    list.add(e);
                }
                return i;
            }
        };
    }

    private static class Side {
        final Map<Integer, Integer> split = new HashMap<Integer, Integer>();
        RelNode newInput;

        private Side() {
        }
    }
}

