/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.CompositeList;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Util;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
import org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlCountAggFunction;
import org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlSumAggFunction;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.TypeConverter;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;

/**
 * Planner rule that rewrites reducible aggregate calls of a {@link HiveAggregate} into
 * simpler aggregate calls plus a projection on top of the new aggregate.
 *
 * <p>A call is reducible when its kind is contained in {@link SqlKind#AVG_AGG_FUNCTIONS}
 * or is {@link SqlKind#SUM0}. The rewrites are:
 * <ul>
 *   <li>{@code $SUM0(x)} becomes {@code COALESCE(SUM(x), 0)};</li>
 *   <li>{@code AVG(x)} becomes {@code SUM(x) / COUNT(x)};</li>
 *   <li>{@code VAR_POP}, {@code VAR_SAMP}, {@code STDDEV_POP} and {@code STDDEV_SAMP}
 *       are expressed with {@code SUM(x * x)}, {@code SUM(x)} and {@code COUNT(x)}.</li>
 * </ul>
 * Each rewritten expression is cast back to the original call's result type. All other
 * aggregate calls are carried over unchanged. The SUM and COUNT calls are created as
 * {@link HiveSqlSumAggFunction} / {@link HiveSqlCountAggFunction} instances with explicit
 * return types.
 */
public class HiveAggregateReduceFunctionsRule extends RelOptRule {

    /** Shared instance of the rule. */
    public static final HiveAggregateReduceFunctionsRule INSTANCE = new HiveAggregateReduceFunctionsRule();

    /** Creates a rule matching any {@link HiveAggregate}, building with the Hive rel builder. */
    public HiveAggregateReduceFunctionsRule() {
        super(operand(HiveAggregate.class, any()), HiveRelFactories.HIVE_BUILDER, null);
    }

    /**
     * Fires only when the matched aggregate has at least one call this rule can reduce.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        if (!super.matches(call)) {
            return false;
        }
        final Aggregate oldAggRel = call.rel(0);
        return containsReducibleCall(oldAggRel.getAggCallList());
    }

    @Override
    public void onMatch(RelOptRuleCall ruleCall) {
        final Aggregate oldAggRel = ruleCall.rel(0);
        reduceAggs(ruleCall, oldAggRel);
    }

    /** Returns true if any call in {@code aggCallList} has a reducible kind. */
    private boolean containsReducibleCall(List<AggregateCall> aggCallList) {
        for (AggregateCall call : aggCallList) {
            if (isReducible(call.getAggregation().getKind())) {
                return true;
            }
        }
        return false;
    }

    /** Returns true for the AVG-family kinds and for {@code $SUM0}. */
    private boolean isReducible(SqlKind kind) {
        return SqlKind.AVG_AGG_FUNCTIONS.contains(kind) || kind == SqlKind.SUM0;
    }

    /**
     * Builds the replacement plan and registers it with the planner.
     *
     * <p>The replacement plan has three parts:
     * <ul>
     *   <li>an optional projection below the aggregate, adding any extra input expressions
     *       the reductions need (e.g. {@code x * x});</li>
     *   <li>the new aggregate;</li>
     *   <li>a projection that rebuilds the original output row type.</li>
     * </ul>
     */
    private void reduceAggs(RelOptRuleCall ruleCall, Aggregate oldAggRel) {
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        final List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
        final int groupCount = oldAggRel.getGroupCount();
        final int indicatorCount = oldAggRel.getIndicatorCount();

        final List<AggregateCall> newCalls = new ArrayList<>();
        final Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();
        final List<RexNode> projList = new ArrayList<>();

        // Group-by (and indicator) columns are projected through as plain input refs.
        for (int i = 0; i < groupCount + indicatorCount; ++i) {
            projList.add(rexBuilder.makeInputRef(getFieldType(oldAggRel, i), i));
        }

        final RelBuilder relBuilder = ruleCall.builder();
        relBuilder.push(oldAggRel.getInput());
        // Starts as the input's fields; reductions may append new expressions to it.
        final List<RexNode> inputExprs = new ArrayList<>(relBuilder.fields());
        for (AggregateCall oldCall : oldCalls) {
            projList.add(reduceAgg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs));
        }

        final int extraArgCount = inputExprs.size() - relBuilder.peek().getRowType().getFieldCount();
        if (extraArgCount > 0) {
            // Keep the existing field names; appended expressions get no explicit name.
            relBuilder.project(inputExprs,
                CompositeList.of(relBuilder.peek().getRowType().getFieldNames(),
                    Collections.<String>nCopies(extraArgCount, null)));
        }
        newAggregateRel(relBuilder, oldAggRel, newCalls);
        relBuilder.project(projList, oldAggRel.getRowType().getFieldNames())
            .convert(oldAggRel.getRowType(), false);
        ruleCall.transformTo(relBuilder.build());
    }

    /**
     * Returns the expression that replaces {@code oldCall} in the top projection.
     *
     * <p>Reducible calls are delegated to the matching {@code reduce*} method. Other
     * calls are passed unchanged to {@link RexBuilder#addAggCall}.
     *
     * @throws Error from {@link Util#unexpected} if a reducible kind has no reduction
     */
    private RexNode reduceAgg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls,
            Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
        final SqlKind kind = oldCall.getAggregation().getKind();
        if (isReducible(kind)) {
            switch (kind) {
                case SUM0:
                    return reduceSum0(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs);
                case AVG:
                    return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs);
                case STDDEV_POP:
                    return reduceStddev(oldAggRel, oldCall, true, true, newCalls, aggCallMapping, inputExprs);
                case STDDEV_SAMP:
                    return reduceStddev(oldAggRel, oldCall, false, true, newCalls, aggCallMapping, inputExprs);
                case VAR_POP:
                    return reduceStddev(oldAggRel, oldCall, true, false, newCalls, aggCallMapping, inputExprs);
                case VAR_SAMP:
                    return reduceStddev(oldAggRel, oldCall, false, false, newCalls, aggCallMapping, inputExprs);
                default:
                    throw Util.unexpected(kind);
            }
        }
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        final int nGroups = oldAggRel.getGroupCount();
        final List<RelDataType> oldArgTypes =
            SqlTypeUtil.projectTypes(oldAggRel.getInput().getRowType(), oldCall.getArgList());
        return rexBuilder.addAggCall(oldCall, nGroups, /* indicator */ false, newCalls, aggCallMapping,
            oldArgTypes);
    }

    /**
     * Creates a single-argument aggregate call whose return type is inferred by
     * {@code aggFunction} from an {@link Aggregate.AggCallBinding} over {@code operandType}.
     * Distinctness, approximation and filter are copied from {@code oldCall}.
     */
    private AggregateCall createAggregateCallWithBinding(RelDataTypeFactory typeFactory,
            SqlAggFunction aggFunction, RelDataType operandType, Aggregate oldAggRel,
            AggregateCall oldCall, int argOrdinal) {
        final Aggregate.AggCallBinding binding = new Aggregate.AggCallBinding(typeFactory, aggFunction,
            ImmutableList.of(operandType), oldAggRel.getGroupCount(), oldCall.filterArg >= 0);
        return AggregateCall.create(aggFunction, oldCall.isDistinct(), oldCall.isApproximate(),
            ImmutableIntList.of(argOrdinal), oldCall.filterArg, aggFunction.inferReturnType(binding), null);
    }

    /**
     * Rewrites {@code $SUM0(x)} as {@code CAST(COALESCE(SUM(x), 0) AS <original type>)}.
     */
    private RexNode reduceSum0(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls,
            Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
        final int nGroups = oldAggRel.getGroupCount();
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
        final int argOrdinal = oldCall.getArgList().get(0);
        final RelDataType sum0InputType = typeFactory.createTypeWithNullability(
            getFieldType(oldAggRel.getInput(), argOrdinal), true);
        final RelDataType sumReturnType = getSumReturnType(rexBuilder.getTypeFactory(), sum0InputType);

        final AggregateCall sumCall = AggregateCall.create(
            new HiveSqlSumAggFunction(oldCall.isDistinct(), ReturnTypes.explicit(sumReturnType),
                oldCall.getAggregation().getOperandTypeInference(),
                oldCall.getAggregation().getOperandTypeChecker()),
            oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), oldCall.filterArg,
            oldAggRel.getGroupCount(), oldAggRel.getInput(), null, null);
        RexNode refSum = rexBuilder.addAggCall(sumCall, nGroups, /* indicator */ false, newCalls,
            aggCallMapping, ImmutableList.of(sum0InputType));
        refSum = rexBuilder.ensureType(oldCall.getType(), refSum, true);

        final RexNode coalesce = rexBuilder.makeCall(SqlStdOperatorTable.COALESCE,
            refSum, rexBuilder.makeZeroLiteral(refSum.getType()));
        return rexBuilder.makeCast(oldCall.getType(), coalesce);
    }

    /**
     * Rewrites {@code AVG(x)} as {@code CAST(SUM(x) / COUNT(x) AS <original type>)}.
     */
    private RexNode reduceAvg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls,
            Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
        final int nGroups = oldAggRel.getGroupCount();
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
        final int argOrdinal = oldCall.getArgList().get(0);
        final RelDataType avgInputType = typeFactory.createTypeWithNullability(
            getFieldType(oldAggRel.getInput(), argOrdinal), true);
        final RelDataType sumReturnType = getSumReturnType(rexBuilder.getTypeFactory(), avgInputType);

        final AggregateCall sumCall = AggregateCall.create(
            new HiveSqlSumAggFunction(oldCall.isDistinct(), ReturnTypes.explicit(sumReturnType),
                oldCall.getAggregation().getOperandTypeInference(),
                oldCall.getAggregation().getOperandTypeChecker()),
            oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), oldCall.filterArg,
            oldAggRel.getGroupCount(), oldAggRel.getInput(), null, null);

        final RelDataType countRetType = typeFactory.createTypeWithNullability(
            typeFactory.createSqlType(SqlTypeName.BIGINT), true);
        final AggregateCall countCall = AggregateCall.create(
            new HiveSqlCountAggFunction(oldCall.isDistinct(), ReturnTypes.explicit(countRetType),
                oldCall.getAggregation().getOperandTypeInference(),
                oldCall.getAggregation().getOperandTypeChecker()),
            oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), oldCall.filterArg,
            oldAggRel.getGroupCount(), oldAggRel.getInput(), countRetType, null);

        RexNode numeratorRef = rexBuilder.addAggCall(sumCall, nGroups, /* indicator */ false, newCalls,
            aggCallMapping, ImmutableList.of(avgInputType));
        final RexNode denominatorRef = rexBuilder.addAggCall(countCall, nGroups, /* indicator */ false,
            newCalls, aggCallMapping, ImmutableList.of(avgInputType));

        // For non-DECIMAL sums, the numerator is first coerced to the AVG result type so the
        // division happens in that type (presumably to match Hive's AVG semantics).
        if (numeratorRef.getType().getSqlTypeName() != SqlTypeName.DECIMAL) {
            numeratorRef = rexBuilder.ensureType(oldCall.getType(), numeratorRef, true);
        }
        final RexNode divideRef = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numeratorRef, denominatorRef);
        return rexBuilder.makeCast(oldCall.getType(), divideRef);
    }

    /**
     * Rewrites a variance / standard-deviation call. Here {@code x} is the argument cast to
     * the nullable result type of the original call:
     * <pre>
     *   var_pop(x)     = (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x)
     *   var_samp(x)    = (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
     *                    / CASE WHEN COUNT(x) = 1 THEN NULL ELSE COUNT(x) - 1 END
     *   stddev_pop(x)  = POWER(var_pop(x), 0.5)
     *   stddev_samp(x) = POWER(var_samp(x), 0.5)
     * </pre>
     *
     * @param biased true to divide by {@code COUNT(x)} (population),
     *               false to divide by {@code COUNT(x) - 1} (sample)
     * @param sqrt   true to take the square root (standard deviation), false for variance
     */
    private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall, boolean biased, boolean sqrt,
            List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping,
            List<RexNode> inputExprs) {
        final int nGroups = oldAggRel.getGroupCount();
        final RelOptCluster cluster = oldAggRel.getCluster();
        final RexBuilder rexBuilder = cluster.getRexBuilder();
        final RelDataTypeFactory typeFactory = cluster.getTypeFactory();

        assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
        final int argOrdinal = oldCall.getArgList().get(0);
        final RelDataType argOrdinalType = getFieldType(oldAggRel.getInput(), argOrdinal);
        final RelDataType oldCallType = typeFactory.createTypeWithNullability(oldCall.getType(), true);

        // x, coerced to the nullable result type; appended to inputExprs if not already present.
        final RexNode argRef = rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), false);
        final int argRefOrdinal = lookupOrAdd(inputExprs, argRef);
        final RelDataType sumReturnType = getSumReturnType(rexBuilder.getTypeFactory(), argRef.getType());

        // x * x
        final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
        final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
        final RelDataType sumSquaredReturnType =
            getSumReturnType(rexBuilder.getTypeFactory(), argSquared.getType());

        // SUM(x * x)
        final AggregateCall sumOfSquaresAggCall = createAggregateCallWithBinding(typeFactory,
            new HiveSqlSumAggFunction(oldCall.isDistinct(), ReturnTypes.explicit(sumSquaredReturnType),
                InferTypes.explicit(Collections.singletonList(argSquared.getType())),
                oldCall.getAggregation().getOperandTypeChecker()),
            argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal);
        final RexNode sumOfSquares = rexBuilder.addAggCall(sumOfSquaresAggCall, nGroups,
            /* indicator */ false, newCalls, aggCallMapping, ImmutableList.of(sumOfSquaresAggCall.getType()));

        // SUM(x)
        final AggregateCall sumAggCall = AggregateCall.create(
            new HiveSqlSumAggFunction(oldCall.isDistinct(), ReturnTypes.explicit(sumReturnType),
                InferTypes.explicit(Collections.singletonList(argOrdinalType)),
                oldCall.getAggregation().getOperandTypeChecker()),
            oldCall.isDistinct(), oldCall.isApproximate(), ImmutableIntList.of(argRefOrdinal),
            oldCall.filterArg, oldAggRel.getGroupCount(), oldAggRel.getInput(), null, null);
        final RexNode sum = rexBuilder.addAggCall(sumAggCall, nGroups, /* indicator */ false, newCalls,
            aggCallMapping, ImmutableList.of(sumAggCall.getType()));
        final RexNode sumCast = rexBuilder.ensureType(oldCallType, sum, true);

        // SUM(x) * SUM(x)
        final RexNode squareOfSum = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumCast, sumCast);

        // COUNT(x), over the original (un-cast) argument.
        final RelDataType countRetType = typeFactory.createTypeWithNullability(
            typeFactory.createSqlType(SqlTypeName.BIGINT), true);
        final AggregateCall countAggCall = AggregateCall.create(
            new HiveSqlCountAggFunction(oldCall.isDistinct(), ReturnTypes.explicit(countRetType),
                oldCall.getAggregation().getOperandTypeInference(),
                oldCall.getAggregation().getOperandTypeChecker()),
            oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), oldCall.filterArg,
            oldAggRel.getGroupCount(), oldAggRel.getInput(), countRetType, null);
        final RexNode count = rexBuilder.addAggCall(countAggCall, nGroups, /* indicator */ false, newCalls,
            aggCallMapping, ImmutableList.of(argOrdinalType));

        // SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)
        final RexNode avgSquareOfSum = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, squareOfSum, count);
        final RexNode diff = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumOfSquares, avgSquareOfSum);

        final RexNode denominator;
        if (biased) {
            denominator = count;
        } else {
            // CASE WHEN COUNT(x) = 1 THEN NULL ELSE COUNT(x) - 1 END
            final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
            final RexNode nul = rexBuilder.makeCast(count.getType(), rexBuilder.constantNull());
            final RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, count, one);
            final RexNode countEqOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, count, one);
            denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne);
        }
        final RexNode variance = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, diff, denominator);
        RexNode result = variance;
        if (sqrt) {
            final RexLiteral half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
            result = rexBuilder.makeCall(SqlStdOperatorTable.POWER, variance, half);
        }
        return rexBuilder.makeCast(oldCall.getType(), result);
    }

    /**
     * Returns the index of an expression in {@code list} whose string form equals
     * {@code element}'s. If there is none, appends {@code element} and returns its new index.
     *
     * <p>NOTE(review): matching is by {@code toString()} rather than {@code equals},
     * presumably because RexNode equality is not structural in the Calcite version in
     * use. TODO confirm.
     */
    private static int lookupOrAdd(List<RexNode> list, RexNode element) {
        final String digest = element.toString();
        for (int ordinal = 0; ordinal < list.size(); ++ordinal) {
            if (list.get(ordinal).toString().equals(digest)) {
                return ordinal;
            }
        }
        list.add(element);
        return list.size() - 1;
    }

    /**
     * Pushes onto {@code relBuilder} an aggregate with the old aggregate's group set and
     * grouping sets and the given calls. Subclasses may override.
     */
    protected void newAggregateRel(RelBuilder relBuilder, Aggregate oldAggregate, List<AggregateCall> newCalls) {
        relBuilder.aggregate(relBuilder.groupKey(oldAggregate.getGroupSet(), oldAggregate.getGroupSets()), newCalls);
    }

    /** Returns the type of field {@code i} of {@code relNode}'s row type. */
    private RelDataType getFieldType(RelNode relNode, int i) {
        final RelDataTypeField inputField = relNode.getRowType().getFieldList().get(i);
        return inputField.getType();
    }

    /**
     * Returns the return type used for a Hive SUM over {@code inputType}.
     *
     * <ul>
     *   <li>Integral inputs map to Hive {@code bigint} (long).</li>
     *   <li>TIMESTAMP, FLOAT, DOUBLE, VARCHAR and CHAR inputs map to Hive {@code double}.</li>
     *   <li>DECIMAL inputs use the type system's derived sum type.</li>
     * </ul>
     *
     * @throws IllegalStateException if no SUM return type is defined for the input's SQL type.
     *         Previously {@code null} was returned here and only failed later when passed to
     *         {@code ReturnTypes.explicit}.
     */
    private RelDataType getSumReturnType(RelDataTypeFactory typeFactory, RelDataType inputType) {
        switch (inputType.getSqlTypeName()) {
            case TINYINT:
            case SMALLINT:
            case INTEGER:
            case BIGINT:
                return TypeConverter.convert(TypeInfoFactory.longTypeInfo, typeFactory);
            case TIMESTAMP:
            case FLOAT:
            case DOUBLE:
            case VARCHAR:
            case CHAR:
                return TypeConverter.convert(TypeInfoFactory.doubleTypeInfo, typeFactory);
            case DECIMAL:
                return typeFactory.getTypeSystem().deriveSumType(typeFactory, inputType);
            default:
                throw new IllegalStateException(
                    "Cannot reduce aggregate: no SUM return type defined for input type " + inputType);
        }
    }
}

