/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class HiveSemiJoinRule {
    /** Matches {@code Project(Join(RelNode, Aggregate(RelNode)))}, i.e. the aggregate is the join's right input. */
    public static final HiveProjectJoinToSemiJoinRule INSTANCE_PROJECT = new HiveProjectJoinToSemiJoinRule();

    /** Matches {@code Project(Join(Aggregate(RelNode), RelNode))}: the aggregate is the join's left input. */
    public static final HiveProjectJoinToSemiJoinRuleSwapInputs INSTANCE_PROJECT_SWAPPED = new HiveProjectJoinToSemiJoinRuleSwapInputs();

    /** Matches {@code Aggregate(Join(RelNode, Aggregate(RelNode)))}, i.e. the aggregate is the join's right input. */
    public static final HiveAggregateJoinToSemiJoinRule INSTANCE_AGGREGATE = new HiveAggregateJoinToSemiJoinRule();

    /** Matches {@code Aggregate(Join(Aggregate(RelNode), RelNode))}: the aggregate is the join's left input. */
    public static final HiveAggregateJoinToSemiJoinRuleSwapInputs INSTANCE_AGGREGATE_SWAPPED = new HiveAggregateJoinToSemiJoinRuleSwapInputs();

    /** Only a holder for the rule singletons above; never instantiated. */
    private HiveSemiJoinRule() {
        // no instances
    }

    /**
     * Rebuilds {@code topProject} on top of {@code newInputOperator}.
     *
     * <p>Each input reference {@code $i} in the project expressions is rewritten to
     * {@code $(i + adjustments[i])}. The rewritten reference takes the type of the field it now
     * points at in {@code newInputOperator}.
     *
     * @param builder builder used to create the new project
     * @param rexBuilder builder for the rewritten expressions
     * @param adjustments per-field offset, indexed by the field position in the old input
     * @param topProject project whose expressions are re-targeted
     * @param newInputOperator the new input of the project
     * @param force forwarded to {@link RelBuilder#project(Iterable, Iterable, boolean)}; when true a
     *     Project is created even if the projection is the identity
     * @return the rebuilt operator
     */
    protected static RelNode recreateProjectOperator(RelBuilder builder, RexBuilder rexBuilder, int[] adjustments, Project topProject, RelNode newInputOperator, boolean force) {
        // Source and destination fields are both the new input's fields, exactly as before.
        final RelOptUtil.RexInputConverter shifter = new RelOptUtil.RexInputConverter(
            rexBuilder,
            newInputOperator.getRowType().getFieldList(),
            newInputOperator.getRowType().getFieldList(),
            adjustments);
        final List<RexNode> shiftedExprs = topProject.getProjects().stream()
            .map(expr -> expr.accept(shifter))
            .collect(Collectors.toList());
        builder.push(newInputOperator);
        builder.project(shiftedExprs, ImmutableList.of(), force);
        return builder.build();
    }

    /**
     * Rebuilds {@code topAggregate} on top of {@code newInputOperator}.
     *
     * <p>Every input position {@code p} the aggregate refers to is rewritten to
     * {@code p + adjustments[p]}. This covers the group set, the grouping sets, and each aggregate
     * call's arguments, filter argument and collation keys. The aggregate calls keep their original
     * types.
     *
     * @param builder builder used to create the new aggregate; its stack depth on return equals its
     *     depth on entry
     * @param adjustments per-field offset, indexed by the field position in the old input
     * @param topAggregate aggregate whose references are re-targeted
     * @param newInputOperator the new input of the aggregate
     * @return the rebuilt operator, as produced by {@link RelBuilder#aggregate}
     */
    protected static RelNode recreateAggregateOperator(RelBuilder builder, int[] adjustments, Aggregate topAggregate, RelNode newInputOperator) {
        // groupKey() resolves its fields against the top of the builder stack, so the new input must
        // be pushed first. Push it only once: aggregate() consumes exactly one input. A second push
        // would leave a stray node on the builder stack.
        builder.push(newInputOperator);
        final ImmutableBitSet newGroupSet = shiftBits(topAggregate.getGroupSet(), adjustments);
        final RelBuilder.GroupKey groupKey;
        if (topAggregate.getGroupType() == Aggregate.Group.SIMPLE) {
            groupKey = builder.groupKey(newGroupSet);
        } else {
            final List<ImmutableBitSet> newGroupSets = new ArrayList<>();
            for (ImmutableBitSet groupingSet : topAggregate.getGroupSets()) {
                newGroupSets.add(shiftBits(groupingSet, adjustments));
            }
            groupKey = builder.groupKey(newGroupSet, newGroupSets);
        }
        final List<AggregateCall> newAggCallList = new ArrayList<>();
        for (AggregateCall aggregateCall : topAggregate.getAggCallList()) {
            final List<Integer> newArgList = aggregateCall.getArgList().stream()
                .map(pos -> pos + adjustments[pos])
                .collect(Collectors.toList());
            // -1 means "no filter" and must not be shifted.
            final int newFilterArg = aggregateCall.filterArg != -1
                ? aggregateCall.filterArg + adjustments[aggregateCall.filterArg]
                : -1;
            final RelCollation newCollation = aggregateCall.getCollation() != null
                ? RelCollations.of(aggregateCall.getCollation().getFieldCollations().stream()
                    .map(fc -> fc.withFieldIndex(fc.getFieldIndex() + adjustments[fc.getFieldIndex()]))
                    .collect(Collectors.toList()))
                : null;
            newAggCallList.add(aggregateCall.copy(newArgList, newFilterArg, newCollation));
        }
        return builder.aggregate(groupKey, newAggCallList).build();
    }

    /** Returns {@code bits} with every set position {@code p} moved to {@code p + adjustments[p]}. */
    private static ImmutableBitSet shiftBits(ImmutableBitSet bits, int[] adjustments) {
        final ImmutableBitSet.Builder shifted = ImmutableBitSet.builder();
        for (int pos : bits) {
            shifted.set(pos + adjustments[pos]);
        }
        return shifted.build();
    }

    /**
     * Aggregate flavour of the input-swapping rule.
     *
     * <p>Matches an {@link Aggregate} on top of a {@link Join} whose <em>left</em> input is itself an
     * {@link Aggregate}.
     */
    protected static class HiveAggregateJoinToSemiJoinRuleSwapInputs
    extends HiveToSemiJoinRuleSwapInputs<Aggregate> {
        protected HiveAggregateJoinToSemiJoinRuleSwapInputs() {
            super(Aggregate.class, HiveRelFactories.HIVE_BUILDER);
        }

        /** Returns the input field positions used by {@code aggregate}, via {@code HiveCalciteUtil.extractRefs}. */
        @Override
        protected ImmutableBitSet extractUsedFields(Aggregate aggregate) {
            final ImmutableBitSet refs = HiveCalciteUtil.extractRefs(aggregate);
            return refs;
        }

        /** Rebuilds the aggregate over {@code newInputOperator} and narrows the result to {@link Aggregate}. */
        @Override
        protected Aggregate recreateTopOperator(RelBuilder builder, RexBuilder rexBuilder, int[] adjustments, Aggregate topAggregate, RelNode newInputOperator) {
            final RelNode rebuilt = HiveSemiJoinRule.recreateAggregateOperator(builder, adjustments, topAggregate, newInputOperator);
            return (Aggregate) rebuilt;
        }

        /** Rebuilds the aggregate over {@code newInputOperator}; {@code rexBuilder} is not needed here. */
        @Override
        protected RelNode recreateTopOperatorUnforced(RelBuilder builder, RexBuilder rexBuilder, int[] adjustments, Aggregate topAggregate, RelNode newInputOperator) {
            final RelNode rebuilt = HiveSemiJoinRule.recreateAggregateOperator(builder, adjustments, topAggregate, newInputOperator);
            return rebuilt;
        }
    }

    /**
     * Project flavour of the input-swapping rule.
     *
     * <p>Matches a {@link Project} on top of a {@link Join} whose <em>left</em> input is an
     * {@link Aggregate}.
     */
    protected static class HiveProjectJoinToSemiJoinRuleSwapInputs
    extends HiveToSemiJoinRuleSwapInputs<Project> {
        protected HiveProjectJoinToSemiJoinRuleSwapInputs() {
            super(Project.class, HiveRelFactories.HIVE_BUILDER);
        }

        /** Returns the positions of all input fields referenced by the project's expressions. */
        @Override
        protected ImmutableBitSet extractUsedFields(Project project) {
            final List<RexNode> expressions = project.getProjects();
            return RelOptUtil.InputFinder.bits(expressions, null);
        }

        /** Rebuilds the project with {@code force=true}, so a Project exists even for an identity projection and the cast holds. */
        @Override
        protected Project recreateTopOperator(RelBuilder builder, RexBuilder rexBuilder, int[] adjustments, Project topProject, RelNode newInputOperator) {
            final RelNode rebuilt = HiveSemiJoinRule.recreateProjectOperator(builder, rexBuilder, adjustments, topProject, newInputOperator, true);
            return (Project) rebuilt;
        }

        /** Rebuilds the project with {@code force=false}; RelBuilder may then skip an identity projection. */
        @Override
        protected RelNode recreateTopOperatorUnforced(RelBuilder builder, RexBuilder rexBuilder, int[] adjustments, Project topProject, RelNode newInputOperator) {
            final boolean force = false;
            return HiveSemiJoinRule.recreateProjectOperator(builder, rexBuilder, adjustments, topProject, newInputOperator, force);
        }
    }

    /**
     * Base class of the "swapped inputs" rule variants.
     *
     * <p>Matches {@code Top(Join(Aggregate(RelNode), RelNode))}, i.e. plans in which the aggregate is
     * the join's <em>left</em> input. The join is mirrored so that the aggregate becomes the right
     * input, the top operator is re-targeted accordingly, and the result is handed to
     * {@link HiveSemiJoinRuleBase#perform}.
     *
     * @param <T> type of the operator on top of the join
     */
    private static abstract class HiveToSemiJoinRuleSwapInputs<T extends RelNode>
    extends HiveSemiJoinRuleBase<T> {
        protected HiveToSemiJoinRuleSwapInputs(Class<T> clazz, RelBuilderFactory relBuilder) {
            super(operand(clazz,
                    operand(Join.class,
                        operand(Aggregate.class, operand(RelNode.class, any())),
                        operand(RelNode.class, any()))),
                relBuilder);
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            final T topOperator = call.rel(0);
            final Join join = call.rel(1);
            if (join.isSemiJoin()) {
                return;
            }
            // After mirroring, perform() can only rewrite INNER joins and LEFT joins. An INNER join
            // stays INNER and a RIGHT join becomes LEFT. Every other type is left untouched.
            final JoinRelType joinType = join.getJoinType();
            if (joinType != JoinRelType.INNER && joinType != JoinRelType.RIGHT) {
                return;
            }
            final Aggregate aggregate = call.rel(2);
            final RelNode aggregateInput = call.rel(3);
            final RelNode right = call.rel(4);
            final JoinInfo joinInfo = join.analyzeCondition();
            if (!joinInfo.isEqui()) {
                return;
            }
            // The aggregate-side join keys must be exactly the aggregate's group keys.
            if (!joinInfo.leftSet().equals(ImmutableBitSet.range(aggregate.getGroupCount()))) {
                return;
            }
            // The top operator must not use any column coming from the aggregate side.
            final ImmutableBitSet leftBits = ImmutableBitSet.range(0, join.getLeft().getRowType().getFieldCount());
            if (this.extractUsedFields(topOperator).intersects(leftBits)) {
                return;
            }
            final T swappedTopOperator = this.swapInputs(join, topOperator, call.builder());
            final Join swappedJoin = (Join) swappedTopOperator.getInput(0);
            final ImmutableBitSet swappedTopRefs = this.extractUsedFields(swappedTopOperator);
            this.perform(call, swappedTopRefs, swappedTopOperator, swappedJoin, right, aggregate, aggregateInput);
        }

        /**
         * Mirrors {@code join} and rebuilds {@code topOperator} on top of the mirrored join.
         *
         * <p>The mirrored join is {@code Join(right, left)}. Its join type is
         * {@code join.getJoinType().swap()}, so outer-join semantics are preserved
         * (LEFT becomes RIGHT and vice versa).
         *
         * @param join the join to mirror
         * @param topOperator operator currently on top of {@code join}
         * @param builder builder used for the new join and the rebuilt top operator
         * @return the rebuilt top operator, whose input 0 is the mirrored join
         */
        protected T swapInputs(Join join, T topOperator, RelBuilder builder) {
            final RexBuilder rexBuilder = join.getCluster().getRexBuilder();
            final List<RelDataTypeField> leftFields = join.getLeft().getRowType().getFieldList();
            final List<RelDataTypeField> rightFields = join.getRight().getRowType().getFieldList();
            final int leftInputSize = leftFields.size();
            final int rightInputSize = rightFields.size();

            // Field i of the original join moves to i + adjustments[i] in the mirrored join. Left
            // columns move past the right input; right columns move to the front.
            final int[] adjustments = new int[leftInputSize + rightInputSize];
            for (int i = 0; i < adjustments.length; ++i) {
                adjustments[i] = i < leftInputSize ? rightInputSize : -leftInputSize;
            }

            // The condition refers to the concatenated input fields. Give every rewritten reference
            // the type of the field it points at in the *mirrored* layout. Using the old layout would
            // pick the type of an unrelated column at that position.
            final List<RelDataTypeField> originalInputFields = new ArrayList<>(leftFields);
            originalInputFields.addAll(rightFields);
            final List<RelDataTypeField> swappedInputFields = new ArrayList<>(rightFields);
            swappedInputFields.addAll(leftFields);
            final RexNode newJoinCond = join.getCondition().accept(
                new RelOptUtil.RexInputConverter(rexBuilder, originalInputFields, swappedInputFields, adjustments));

            // Mirroring the inputs must also mirror the join type. Otherwise "Agg LEFT JOIN X" would
            // turn into "X LEFT JOIN Agg", which has different semantics.
            final RelNode swappedJoin = builder.push(join.getRight()).push(join.getLeft())
                .join(join.getJoinType().swap(), newJoinCond)
                .build();
            return this.recreateTopOperator(builder, rexBuilder, adjustments, topOperator, swappedJoin);
        }
    }

    /**
     * Aggregate flavour of the semi-join rule.
     *
     * <p>Matches an {@link Aggregate} on top of a {@link Join} whose <em>right</em> input is an
     * {@link Aggregate}.
     */
    protected static class HiveAggregateJoinToSemiJoinRule
    extends HiveSemiJoinRuleBase<Aggregate> {
        protected HiveAggregateJoinToSemiJoinRule() {
            super(Aggregate.class, HiveRelFactories.HIVE_BUILDER);
        }

        /** Returns the input field positions used by {@code aggregate}, via {@code HiveCalciteUtil.extractRefs}. */
        @Override
        protected ImmutableBitSet extractUsedFields(Aggregate aggregate) {
            final ImmutableBitSet usedFields = HiveCalciteUtil.extractRefs(aggregate);
            return usedFields;
        }

        /** Rebuilds the aggregate over {@code newInputOperator} and narrows the result to {@link Aggregate}. */
        @Override
        protected Aggregate recreateTopOperator(RelBuilder builder, RexBuilder rexBuilder, int[] adjustments, Aggregate topAggregate, RelNode newInputOperator) {
            final RelNode recreated = HiveSemiJoinRule.recreateAggregateOperator(builder, adjustments, topAggregate, newInputOperator);
            return (Aggregate) recreated;
        }

        /** Rebuilds the aggregate over {@code newInputOperator}; {@code rexBuilder} is not needed here. */
        @Override
        protected RelNode recreateTopOperatorUnforced(RelBuilder builder, RexBuilder rexBuilder, int[] adjustments, Aggregate topAggregate, RelNode newInputOperator) {
            final RelNode recreated = HiveSemiJoinRule.recreateAggregateOperator(builder, adjustments, topAggregate, newInputOperator);
            return recreated;
        }
    }

    /**
     * Project flavour of the semi-join rule.
     *
     * <p>Matches a {@link Project} on top of a {@link Join} whose <em>right</em> input is an
     * {@link Aggregate}.
     */
    protected static class HiveProjectJoinToSemiJoinRule
    extends HiveSemiJoinRuleBase<Project> {
        protected HiveProjectJoinToSemiJoinRule() {
            super(Project.class, HiveRelFactories.HIVE_BUILDER);
        }

        /** Returns the positions of all input fields referenced by the project's expressions. */
        @Override
        protected ImmutableBitSet extractUsedFields(Project project) {
            final List<RexNode> expressions = project.getProjects();
            return RelOptUtil.InputFinder.bits(expressions, null);
        }

        /** Rebuilds the project with {@code force=true}, so a Project exists even for an identity projection and the cast holds. */
        @Override
        protected Project recreateTopOperator(RelBuilder builder, RexBuilder rexBuilder, int[] adjustments, Project topProject, RelNode newInputOperator) {
            final RelNode recreated = HiveSemiJoinRule.recreateProjectOperator(builder, rexBuilder, adjustments, topProject, newInputOperator, true);
            return (Project) recreated;
        }

        /** Rebuilds the project with {@code force=false}; RelBuilder may then skip an identity projection. */
        @Override
        protected RelNode recreateTopOperatorUnforced(RelBuilder builder, RexBuilder rexBuilder, int[] adjustments, Project topProject, RelNode newInputOperator) {
            final boolean force = false;
            return HiveSemiJoinRule.recreateProjectOperator(builder, rexBuilder, adjustments, topProject, newInputOperator, force);
        }
    }

    /**
     * Shared implementation of the semi-join rewrite.
     *
     * <p>The default operand matches {@code Top(Join(RelNode left, Aggregate(RelNode)))}, where
     * {@code Top} is a {@code T}. The rewrite applies only when all of these hold:
     * <ul>
     *   <li>{@code Top} uses no column of the aggregate side;</li>
     *   <li>the join is an equi-join;</li>
     *   <li>the join's right-side keys are exactly the aggregate's group keys.</li>
     * </ul>
     * When they hold, the rule:
     * <ul>
     *   <li>for a LEFT join, places {@code Top} directly on {@code left};</li>
     *   <li>for an INNER join over a simple, non-scalar aggregate, replaces the join by a semi-join
     *       of (a projection of) {@code left} against the aggregate's input, projected to the group
     *       keys when needed.</li>
     * </ul>
     *
     * @param <T> type of the operator on top of the join
     */
    private static abstract class HiveSemiJoinRuleBase<T extends RelNode>
    extends RelOptRule {
        protected static final Logger LOG = LoggerFactory.getLogger(HiveSemiJoinRuleBase.class);

        protected HiveSemiJoinRuleBase(Class<T> clazz, RelBuilderFactory relBuilder) {
            super(operand(clazz,
                    operand(Join.class,
                        operand(RelNode.class, any()),
                        operand(Aggregate.class, operand(RelNode.class, any())))),
                relBuilder, null);
        }

        protected HiveSemiJoinRuleBase(RelOptRuleOperand operand, RelBuilderFactory relBuilder) {
            super(operand, relBuilder, null);
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            final T topOperator = call.rel(0);
            final Join join = call.rel(1);
            if (join.isSemiJoin()) {
                return;
            }
            final RelNode left = call.rel(2);
            final Aggregate aggregate = call.rel(3);
            final RelNode aggregateInput = call.rel(4);
            final ImmutableBitSet topRefs = this.extractUsedFields(topOperator);
            this.perform(call, topRefs, topOperator, join, left, aggregate, aggregateInput);
        }

        /**
         * Returns whether the aggregate's input must be projected to the group keys before it can
         * serve as the semi-join's right input.
         *
         * <p>If the input has exactly {@code groupCount} columns, the group set (which has
         * {@code groupCount} bits, all below that width) is {@code range(groupCount)}. The input
         * columns are then the group keys in order and can be used directly. The aggregate's own row
         * width is not a valid proxy, because it also counts aggregate calls.
         */
        private boolean needProject(RelNode input, Aggregate aggregate) {
            // NOTE(review): a Join input is always wrapped in a Project. The reason is not visible
            // here; presumably downstream code does not expect a Join as a semi-join's right input.
            return input instanceof Join || input.getRowType().getFieldCount() != aggregate.getGroupCount();
        }

        /**
         * Attempts the rewrite described on the class.
         *
         * @param call the rule call used to register the transformed plan
         * @param topRefs input positions of {@code join} referenced by {@code topOperator}
         * @param topOperator operator on top of {@code join}
         * @param join join whose right input is {@code aggregate}
         * @param left left input of {@code join}
         * @param aggregate right input of {@code join}
         * @param aggregateInput input of {@code aggregate}
         */
        protected void perform(RelOptRuleCall call, ImmutableBitSet topRefs, T topOperator, Join join, RelNode left, Aggregate aggregate, RelNode aggregateInput) {
            LOG.debug("Matched HiveSemiJoinRule");
            final RelOptCluster cluster = join.getCluster();
            final RexBuilder rexBuilder = cluster.getRexBuilder();
            final ImmutableBitSet rightBits = ImmutableBitSet.range(left.getRowType().getFieldCount(), join.getRowType().getFieldCount());
            if (topRefs.intersects(rightBits)) {
                return;
            }
            final JoinInfo joinInfo = join.analyzeCondition();
            if (!joinInfo.rightSet().equals(ImmutableBitSet.range(aggregate.getGroupCount()))) {
                return;
            }
            if (!joinInfo.isEqui()) {
                return;
            }
            if (join.getJoinType() == JoinRelType.LEFT) {
                call.transformTo(topOperator.copy(topOperator.getTraitSet(), ImmutableList.of(left)));
                return;
            }
            if (join.getJoinType() != JoinRelType.INNER) {
                return;
            }
            // Grouping sets do not give one row per distinct key combination of the input. For
            // example, rolled-up rows carry NULLs, and the full key set may be absent entirely.
            // A scalar aggregate (empty group set, join on TRUE) always emits one row, even for
            // empty input. In both cases a semi-join against the raw input is not equivalent.
            if (aggregate.getGroupType() != Aggregate.Group.SIMPLE || aggregate.getGroupCount() == 0) {
                return;
            }
            LOG.debug("All conditions matched for HiveSemiJoinRule. Going to apply transformation.");
            final ImmutableBitSet leftNeededRefs = topRefs.union(ImmutableBitSet.of(joinInfo.leftKeys));
            final boolean updateLeft = leftNeededRefs.cardinality() != left.getRowType().getFieldCount();
            final RelNode newLeft = updateLeft ? this.buildProjectLeftInput(left, leftNeededRefs, rexBuilder, call.builder()) : left;
            final RelNode newRight = this.needProject(aggregateInput, aggregate) ? this.buildProjectRightInput(aggregate, rexBuilder, call.builder()) : aggregateInput;
            final ImmutableIntList leftKeys;
            if (updateLeft) {
                // Re-express the left keys as positions inside the pruned left projection.
                final List<Integer> newLeftKeys = new ArrayList<>();
                for (int pos : joinInfo.leftKeys) {
                    newLeftKeys.add(leftNeededRefs.indexOf(pos));
                }
                leftKeys = ImmutableIntList.copyOf(newLeftKeys);
            } else {
                leftKeys = joinInfo.leftKeys;
            }
            final RexNode newCondition = RelOptUtil.createEquiJoinCondition(newLeft, leftKeys, newRight, joinInfo.rightKeys, rexBuilder);
            final RelNode semi = call.builder().push(newLeft).push(newRight).semiJoin(newCondition).build();
            // Field i of the old left input is at leftNeededRefs.indexOf(i) in the semi-join output.
            // Fields that were not kept get an offset too, but the top operator does not reference them.
            final int[] adjustments = new int[left.getRowType().getFieldCount()];
            for (int i = 0; i < adjustments.length; ++i) {
                adjustments[i] = leftNeededRefs.indexOf(i) - i;
            }
            call.transformTo(this.recreateTopOperatorUnforced(call.builder(), rexBuilder, adjustments, topOperator, semi));
        }

        /** Projects {@code node} down to the columns set in {@code neededRefs}, in ascending order. */
        private RelNode buildProjectLeftInput(RelNode node, ImmutableBitSet neededRefs, RexBuilder rexBuilder, RelBuilder builder) {
            final List<RexNode> exprs = new ArrayList<>();
            for (int pos : neededRefs) {
                exprs.add(rexBuilder.makeInputRef(node, pos));
            }
            return builder.push(node).project(exprs).build();
        }

        /**
         * Projects the aggregate's input down to its group keys, in group-set order. Column i of the
         * result then corresponds to output column i of the aggregate.
         */
        private RelNode buildProjectRightInput(Aggregate aggregate, RexBuilder rexBuilder, RelBuilder relBuilder) {
            // Only the key values matter for a semi-join, so aggregate calls may be present.
            // perform() guarantees a SIMPLE aggregate.
            assert aggregate.getGroupType() == Aggregate.Group.SIMPLE;
            final RelNode input = aggregate.getInput();
            final List<RexNode> projects = new ArrayList<>();
            for (int key : aggregate.getGroupSet()) {
                projects.add(rexBuilder.makeInputRef(input, key));
            }
            // needProject() requests a projection whenever the input is a Join. Force it so RelBuilder
            // does not drop an identity projection and hand back the Join itself.
            return relBuilder.push(input).project(projects, ImmutableList.of(), true).build();
        }

        /** Returns the input positions of the join referenced by {@code topOperator}. */
        protected abstract ImmutableBitSet extractUsedFields(T topOperator);

        /** Rebuilds {@code topOperator} over {@code newInput}, shifting input refs by {@code adjustments}; result must be a {@code T}. */
        protected abstract T recreateTopOperator(RelBuilder builder, RexBuilder rexBuilder, int[] adjustments, T topOperator, RelNode newInput);

        /** Like {@link #recreateTopOperator}, but the implementation may return a node of another type. */
        protected abstract RelNode recreateTopOperatorUnforced(RelBuilder builder, RexBuilder rexBuilder, int[] adjustments, T topOperator, RelNode newInput);
    }
}

