/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.google.common.base.Function;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelTrait;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.TraitsUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveFilter;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveIntersect;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveRelNode;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableFunctionScan;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveUnion;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.SqlFunctionConverter;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.TypeConverter;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class HiveIntersectRewriteRule
extends RelOptRule {
    public static final HiveIntersectRewriteRule INSTANCE = new HiveIntersectRewriteRule();
    protected static final Logger LOG = LoggerFactory.getLogger(HiveIntersectRewriteRule.class);

    private HiveIntersectRewriteRule() {
        super(HiveIntersectRewriteRule.operand(HiveIntersect.class, (RelOptRuleOperandChildren)HiveIntersectRewriteRule.any()));
    }

    public void onMatch(RelOptRuleCall call) {
        HiveIntersect hiveIntersect = (HiveIntersect)call.rel(0);
        RelOptCluster cluster = hiveIntersect.getCluster();
        RexBuilder rexBuilder = cluster.getRexBuilder();
        int numOfBranch = hiveIntersect.getInputs().size();
        ImmutableList.Builder bldr = new ImmutableList.Builder();
        for (int index = 0; index < numOfBranch; ++index) {
            RelNode input = (RelNode)hiveIntersect.getInputs().get(index);
            ArrayList gbChildProjLst = Lists.newArrayList();
            ArrayList groupSetPositions = Lists.newArrayList();
            for (int cInd = 0; cInd < input.getRowType().getFieldList().size(); ++cInd) {
                gbChildProjLst.add(rexBuilder.makeInputRef(input, cInd));
                groupSetPositions.add(cInd);
            }
            gbChildProjLst.add(rexBuilder.makeBigintLiteral(new BigDecimal(1)));
            HiveProject gbInputRel = null;
            try {
                gbInputRel = HiveProject.create(input, gbChildProjLst, null);
            }
            catch (CalciteSemanticException e) {
                LOG.debug(e.toString());
                throw new RuntimeException((Throwable)((Object)e));
            }
            ImmutableBitSet groupSet = ImmutableBitSet.of((Iterable)groupSetPositions);
            ArrayList aggregateCalls = Lists.newArrayList();
            RelDataType aggFnRetType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory());
            AggregateCall aggregateCall = HiveCalciteUtil.createSingleArgAggCall("count", cluster, TypeInfoFactory.longTypeInfo, input.getRowType().getFieldList().size(), aggFnRetType);
            aggregateCalls.add(aggregateCall);
            HiveAggregate aggregateRel = new HiveAggregate(cluster, cluster.traitSetOf((RelTrait)HiveRelNode.CONVENTION), gbInputRel, groupSet, null, aggregateCalls);
            bldr.add((Object)aggregateRel);
        }
        HiveUnion union = new HiveUnion(cluster, TraitsUtil.getDefaultTraitSet(cluster), (List<RelNode>)bldr.build());
        ArrayList groupSetPositions = Lists.newArrayList();
        int cInd = union.getRowType().getFieldList().size() - 1;
        for (int index = 0; index < union.getRowType().getFieldList().size(); ++index) {
            if (index == cInd) continue;
            groupSetPositions.add(index);
        }
        ArrayList aggregateCalls = Lists.newArrayList();
        RelDataType aggFnRetType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory());
        AggregateCall aggregateCall = HiveCalciteUtil.createSingleArgAggCall("count", cluster, TypeInfoFactory.longTypeInfo, cInd, aggFnRetType);
        aggregateCalls.add(aggregateCall);
        if (hiveIntersect.all) {
            aggregateCall = HiveCalciteUtil.createSingleArgAggCall("min", cluster, TypeInfoFactory.longTypeInfo, cInd, aggFnRetType);
            aggregateCalls.add(aggregateCall);
        }
        ImmutableBitSet groupSet = ImmutableBitSet.of((Iterable)groupSetPositions);
        HiveAggregate aggregateRel = new HiveAggregate(cluster, cluster.traitSetOf((RelTrait)HiveRelNode.CONVENTION), union, groupSet, null, aggregateCalls);
        int countInd = cInd;
        ArrayList<Object> childRexNodeLst = new ArrayList<Object>();
        RexInputRef ref = rexBuilder.makeInputRef((RelNode)aggregateRel, countInd);
        RexLiteral literal = rexBuilder.makeBigintLiteral(new BigDecimal(numOfBranch));
        childRexNodeLst.add(ref);
        childRexNodeLst.add(literal);
        ImmutableList.Builder calciteArgTypesBldr = new ImmutableList.Builder();
        calciteArgTypesBldr.add((Object)TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory()));
        calciteArgTypesBldr.add((Object)TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory()));
        RexNode factoredFilterExpr = null;
        try {
            factoredFilterExpr = rexBuilder.makeCall(SqlFunctionConverter.getCalciteFn("=", (ImmutableList<RelDataType>)calciteArgTypesBldr.build(), TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory()), true, false), childRexNodeLst);
        }
        catch (CalciteSemanticException e) {
            LOG.debug(e.toString());
            throw new RuntimeException((Throwable)((Object)e));
        }
        HiveFilter filterRel = new HiveFilter(cluster, cluster.traitSetOf((RelTrait)HiveRelNode.CONVENTION), aggregateRel, factoredFilterExpr);
        if (!hiveIntersect.all) {
            HashSet<Integer> projectOutColumnPositions = new HashSet<Integer>();
            projectOutColumnPositions.add(filterRel.getRowType().getFieldList().size() - 1);
            try {
                call.transformTo((RelNode)HiveCalciteUtil.createProjectWithoutColumn(filterRel, projectOutColumnPositions));
            }
            catch (CalciteSemanticException e) {
                LOG.debug(e.toString());
                throw new RuntimeException((Throwable)((Object)e));
            }
        }
        List originalInputRefs = Lists.transform((List)filterRel.getRowType().getFieldList(), (Function)new Function<RelDataTypeField, RexNode>(){

            public RexNode apply(RelDataTypeField input) {
                return new RexInputRef(input.getIndex(), input.getType());
            }
        });
        ArrayList copyInputRefs = new ArrayList();
        copyInputRefs.add(originalInputRefs.get(originalInputRefs.size() - 1));
        for (int i = 0; i < originalInputRefs.size() - 2; ++i) {
            copyInputRefs.add(originalInputRefs.get(i));
        }
        HiveProject srcRel = null;
        try {
            srcRel = HiveProject.create(filterRel, copyInputRefs, null);
            HiveTableFunctionScan udtf = HiveCalciteUtil.createUDTFForSetOp(cluster, srcRel);
            HashSet<Integer> projectOutColumnPositions = new HashSet<Integer>();
            projectOutColumnPositions.add(0);
            call.transformTo((RelNode)HiveCalciteUtil.createProjectWithoutColumn(udtf, projectOutColumnPositions));
        }
        catch (SemanticException e) {
            LOG.debug(e.toString());
            throw new RuntimeException(e);
        }
    }
}

