/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Stack;
import java.util.stream.Collectors;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.metastore.api.FieldSchema;
import org.apache.hadoop.hive.ql.exec.AbstractMapJoinOperator;
import org.apache.hadoop.hive.ql.exec.ColumnInfo;
import org.apache.hadoop.hive.ql.exec.GroupByOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.OperatorFactory;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.RowSchema;
import org.apache.hadoop.hive.ql.exec.SelectOperator;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.io.AcidUtils;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.lib.SemanticNodeProcessor;
import org.apache.hadoop.hive.ql.lib.SemanticRule;
import org.apache.hadoop.hive.ql.optimizer.Transform;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ColStatistics;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.GroupByDesc;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.PlanUtils;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
import org.apache.hadoop.hive.ql.plan.SelectDesc;
import org.apache.hadoop.hive.ql.plan.TableDesc;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Plan transformation that inserts an additional shuffle in front of GROUP BY operators
 * that evaluate grouping sets.
 *
 * <p>For every feasible {@link GroupByOperator} the plan fragment {@code PARENT -> GBY} is
 * rewritten to {@code PARENT -> RS -> SEL -> GBY}: the new {@link ReduceSinkOperator}
 * partitions the parent's rows by one of the grouping-set key columns, and the new
 * {@link SelectOperator} maps the ReduceSink output columns back to the parent's original
 * internal column names.
 */
public class GroupingSetOptimizer extends Transform {
    private static final Logger LOG = LoggerFactory.getLogger(GroupingSetOptimizer.class);

    /**
     * Walks the operator graph from all top operators and applies {@link GroupingSetProcessor}
     * to every operator matched by the GROUP BY rule.
     *
     * @param pCtx parse context providing the top operators and the configuration
     * @return the same {@code pCtx} instance; the operator graph is modified in place
     * @throws SemanticException if the graph walk fails
     */
    @Override
    public ParseContext transform(ParseContext pCtx) throws SemanticException {
        LinkedHashMap<SemanticRule, SemanticNodeProcessor> opRules = new LinkedHashMap<>();
        opRules.put(new RuleRegExp("GBY", GroupByOperator.getOperatorName() + "%"), new GroupingSetProcessor());

        DefaultRuleDispatcher dispatcher =
            new DefaultRuleDispatcher(null, opRules, new GroupingSetProcessorContext(pCtx.getConf()));
        DefaultGraphWalker walker = new DefaultGraphWalker(dispatcher);

        List<Node> topNodes = new ArrayList<>(pCtx.getTopOps().values());
        walker.startWalking(topNodes, null);
        return pCtx;
    }

    /**
     * Node processor that performs the rewrite for a single {@link GroupByOperator}.
     */
    private static class GroupingSetProcessor implements SemanticNodeProcessor {

        /**
         * Inserts {@code RS -> SEL} between the visited GROUP BY operator and its single parent
         * when all feasibility checks pass; otherwise leaves the plan untouched.
         *
         * @param nd the visited node; cast to {@link GroupByOperator}
         * @param stack unused
         * @param procCtx the walker context; cast to {@link GroupingSetProcessorContext}
         * @param nodeOutputs unused
         * @return always {@code null}
         */
        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object... nodeOutputs)
                throws SemanticException {
            GroupByOperator gby = (GroupByOperator) nd;
            GroupingSetProcessorContext context = (GroupingSetProcessorContext) procCtx;
            if (!isGroupByFeasible(gby, context)) {
                return null;
            }

            Operator<?> parentOp = gby.getParentOperators().get(0);
            if (!isParentOpFeasible(parentOp)) {
                return null;
            }

            String partitionCol = selectPartitionColumn(gby, parentOp);
            if (partitionCol == null) {
                return null;
            }

            LOG.info("Applying GroupingSetOptimization: partitioning the input data of {} by {}", gby, partitionCol);

            // The ReduceSink is attached below parentOp first; only then is the old
            // parentOp -> gby link removed, and the Select is attached below the ReduceSink.
            ReduceSinkOperator rs = createReduceSink(parentOp, partitionCol, context);
            parentOp.removeChild(gby);
            SelectOperator sel = createSelect(parentOp.getSchema().getSignature(), partitionCol, rs);

            // Arrays.asList returns a fixed-size view; copy into ArrayList so that the
            // child/parent lists stored in the operators remain structurally modifiable.
            sel.setChildOperators(new ArrayList<>(Arrays.asList(gby)));
            gby.setParentOperators(new ArrayList<>(Arrays.asList(sel)));
            return null;
        }

        /**
         * Checks whether the GROUP BY operator itself qualifies for the rewrite.
         *
         * @return {@code true} only if grouping sets are present, statistics are available,
         *         the estimated row count is not below the configured threshold, and the
         *         operator has exactly one parent
         */
        private boolean isGroupByFeasible(GroupByOperator gby, GroupingSetProcessorContext context) {
            if (!gby.getConf().isGroupingSetsPresent() || gby.getStatistics() == null) {
                return false;
            }
            if (gby.getStatistics().getNumRows() < context.groupingSetThreshold) {
                LOG.debug("Skip grouping-set optimization on a small operator: {}", gby);
                return false;
            }
            if (gby.getParentOperators().size() != 1) {
                LOG.debug("Skip grouping-set optimization on a operator with multiple parent operators: {}", gby);
                return false;
            }
            return true;
        }

        /**
         * Checks whether inserting a shuffle below {@code parentOp} is acceptable.
         *
         * @return {@code false} if the closest upstream ReduceSink (see
         *         {@link #findUpstreamReduceSink}) already has partition columns;
         *         {@code true} otherwise, including when no such ReduceSink is found
         */
        private boolean isParentOpFeasible(Operator<?> parentOp) {
            ReduceSinkOperator rs = findUpstreamReduceSink(parentOp);
            if (rs == null) {
                return true;
            }
            List<ExprNodeDesc> partitionCols = rs.getConf().getPartitionCols();
            if (partitionCols != null && !partitionCols.isEmpty()) {
                LOG.debug("Skip grouping-set optimization in order not to introduce possibly redundant shuffle.");
                return false;
            }
            return true;
        }

        /**
         * Walks upwards from {@code start} until a {@link ReduceSinkOperator} is reached.
         * Single-parent operators are followed to their parent; map-join operators are followed
         * through the parent at the big-table position. The walk stops at any other operator
         * whose parent count is not exactly one, or at an operator whose parent list is null.
         *
         * @return the first ReduceSink found on that path, or {@code null} if the walk stops first
         */
        private ReduceSinkOperator findUpstreamReduceSink(Operator<?> start) {
            Operator<?> curOp = start;
            while (!(curOp instanceof ReduceSinkOperator)) {
                if (curOp.getParentOperators() == null) {
                    return null;
                }
                if (curOp.getParentOperators().size() == 1) {
                    curOp = curOp.getParentOperators().get(0);
                } else if (curOp instanceof AbstractMapJoinOperator) {
                    MapJoinDesc desc = ((AbstractMapJoinOperator<?>) curOp).getConf();
                    curOp = curOp.getParentOperators().get(desc.getPosBigTable());
                } else {
                    return null;
                }
            }
            return (ReduceSinkOperator) curOp;
        }

        /**
         * Chooses the column of {@code parentOp} by which the inserted ReduceSink partitions rows.
         *
         * <p>Candidates are GROUP BY keys that are plain column references, sit at a position
         * reported by {@link #listGroupingSetKeyPositions}, and appear in the parent's schema.
         * Among the parent's column statistics with a positive distinct count, taken in
         * descending distinct-count order, the first column that is both a candidate and a key
         * of the parent's column expression map is chosen.
         *
         * @return the internal name of the chosen column, or {@code null} if the parent lacks a
         *         schema or usable statistics, emits more rows than {@code gby}, or offers no
         *         feasible column
         */
        private String selectPartitionColumn(GroupByOperator gby, Operator<?> parentOp) {
            if (parentOp.getSchema() == null || parentOp.getSchema().getSignature() == null) {
                LOG.debug("Skip grouping-set optimization as the parent operator {} does not provide signature", parentOp);
                return null;
            }
            if (parentOp.getStatistics() == null || parentOp.getStatistics().getNumRows() <= 0L
                    || parentOp.getStatistics().getColumnStats() == null) {
                LOG.debug("Skip grouping-set optimization as the parent operator {} does not provide statistics", parentOp);
                return null;
            }
            if (parentOp.getStatistics().getNumRows() > gby.getStatistics().getNumRows()) {
                LOG.debug("Skip grouping-set optimization as the parent operator {} emits more rows than {}", parentOp, gby);
                return null;
            }

            List<String> colNamesInSignature = new ArrayList<>();
            for (ColumnInfo pColInfo : parentOp.getSchema().getSignature()) {
                colNamesInSignature.add(pColInfo.getInternalName());
            }

            HashSet<String> candidates = new HashSet<>();
            for (Integer keyPosition : listGroupingSetKeyPositions(gby.getConf().getListGroupingSets())) {
                ExprNodeDesc key = gby.getConf().getKeys().get(keyPosition);
                if (key instanceof ExprNodeColumnDesc) {
                    candidates.add(((ExprNodeColumnDesc) key).getColumn());
                }
            }
            candidates.retainAll(colNamesInSignature);

            List<ColStatistics> columnStatistics =
                new ArrayList<ColStatistics>(parentOp.getStatistics().getColumnStats()).stream()
                    .filter(cs -> cs.getCountDistint() > 0L)
                    .sorted(Comparator.comparingLong(ColStatistics::getCountDistint).reversed())
                    .collect(Collectors.toList());

            String partitionCol = null;
            // A parent without a column expression map offers no feasible column; guarding here
            // avoids dereferencing a null map.
            if (parentOp.getColumnExprMap() != null) {
                for (ColStatistics col : columnStatistics) {
                    String colName = col.getColumnName();
                    if (parentOp.getColumnExprMap().containsKey(colName) && candidates.contains(colName)) {
                        partitionCol = colName;
                        break;
                    }
                }
            }
            if (partitionCol == null) {
                LOG.debug("Skip grouping-set optimization as there is no feasible column in parent operator {}.", parentOp);
            }
            return partitionCol;
        }

        /**
         * Creates a ReduceSink as a child of {@code parentOp} that forwards every parent column.
         * The column named {@code partitionColName} becomes the single key and partition column;
         * all other columns become value columns. Output columns are renamed as described in
         * {@link #reduceSinkColumnName}.
         *
         * @param parentOp operator whose schema and data-size statistics are used
         * @param partitionColName internal name of the parent column to partition by
         * @param context supplies bytes-per-reducer and max-reducers for the reducer estimate
         * @return the newly created ReduceSink operator
         */
        private ReduceSinkOperator createReduceSink(Operator<?> parentOp, String partitionColName,
                GroupingSetProcessorContext context) {
            HashMap<String, ExprNodeDesc> colExprMap = new HashMap<>();
            ArrayList<ExprNodeDesc> keyColumns = new ArrayList<>();
            ArrayList<String> keyColumnNames = new ArrayList<>();
            ArrayList<ExprNodeDesc> valueColumns = new ArrayList<>();
            ArrayList<String> valueColumnNames = new ArrayList<>();
            ArrayList<ColumnInfo> signature = new ArrayList<>();
            ArrayList<ExprNodeDesc> partCols = new ArrayList<>();

            for (ColumnInfo pColInfo : parentOp.getSchema().getSignature()) {
                String pColName = pColInfo.getInternalName();
                String rsColName = reduceSinkColumnName(pColName, partitionColName);

                // Output schema entry: a copy of the parent's column under the prefixed name.
                ColumnInfo rsColInfo = new ColumnInfo(pColInfo);
                rsColInfo.setInternalName(rsColName);
                signature.add(rsColInfo);

                ExprNodeColumnDesc expr = new ExprNodeColumnDesc(pColInfo);
                colExprMap.put(rsColName, expr);

                if (pColName.equals(partitionColName)) {
                    keyColumnNames.add(pColName);
                    keyColumns.add(expr);
                    partCols.add(expr);
                } else {
                    valueColumnNames.add(pColName);
                    valueColumns.add(expr);
                }
            }

            List<FieldSchema> valueFields = PlanUtils.getFieldSchemasFromColumnList(valueColumns, valueColumnNames, 0, "");
            TableDesc valueTable = PlanUtils.getReduceValueTableDesc(valueFields);
            List<FieldSchema> keyFields = PlanUtils.getFieldSchemasFromColumnList(keyColumns, keyColumnNames, 0, "");
            TableDesc keyTable = PlanUtils.getReduceKeyTableDesc(keyFields, "+", "z");

            ArrayList<List<Integer>> distinctColumnIndices = new ArrayList<>();
            int numReducers = Utilities.estimateReducers(
                parentOp.getStatistics().getDataSize(), context.bytesPerReducer, context.maxReducers, false);

            ReduceSinkDesc rsConf = new ReduceSinkDesc(keyColumns, keyColumns.size(), valueColumns, keyColumnNames,
                distinctColumnIndices, valueColumnNames, -1, partCols, numReducers, keyTable, valueTable,
                AcidUtils.Operation.NOT_ACID);
            ReduceSinkOperator rs = (ReduceSinkOperator) OperatorFactory.getAndMakeChild(
                rsConf, new RowSchema(signature), parentOp, new Operator[0]);
            rs.setColumnExprMap(colExprMap);
            rsConf.setReducerTraits(EnumSet.of(ReduceSinkDesc.ReducerTraits.UNIFORM, ReduceSinkDesc.ReducerTraits.AUTOPARALLEL));
            return rs;
        }

        /**
         * Creates a Select as a child of {@code parentOp} (the inserted ReduceSink) that exposes
         * every column of {@code signature} under its original internal name, reading it from
         * the correspondingly prefixed ReduceSink output column.
         *
         * @param signature columns of the operator that originally fed the GROUP BY
         * @param partitionColName internal name of the column used as the ReduceSink key
         * @param parentOp operator under which the Select is created
         * @return the newly created Select operator
         */
        private SelectOperator createSelect(List<ColumnInfo> signature, String partitionColName, Operator<?> parentOp) {
            ArrayList<String> selColNames = new ArrayList<>();
            ArrayList<ExprNodeDesc> selColumns = new ArrayList<>();
            ArrayList<ColumnInfo> selSignature = new ArrayList<>();
            HashMap<String, ExprNodeDesc> colExprMap = new HashMap<>();

            for (ColumnInfo pColInfo : signature) {
                String origColName = pColInfo.getInternalName();
                String rsColName = reduceSinkColumnName(origColName, partitionColName);
                ExprNodeColumnDesc selExpr = new ExprNodeColumnDesc(pColInfo.getType(), rsColName, null, false);

                selSignature.add(new ColumnInfo(pColInfo));
                selColumns.add(selExpr);
                selColNames.add(origColName);
                colExprMap.put(origColName, selExpr);
            }

            SelectDesc selConf = new SelectDesc(selColumns, selColNames);
            SelectOperator sel = (SelectOperator) OperatorFactory.getAndMakeChild(
                selConf, new RowSchema(selSignature), parentOp, new Operator[0]);
            sel.setColumnExprMap(colExprMap);
            return sel;
        }

        /**
         * Returns the name a parent column carries in the output of the inserted ReduceSink:
         * {@code KEY.<name>} for the partition column and {@code VALUE.<name>} for any other.
         * Shared by {@link #createReduceSink} and {@link #createSelect} so both sides agree.
         */
        private String reduceSinkColumnName(String colName, String partitionColName) {
            Utilities.ReduceField field =
                colName.equals(partitionColName) ? Utilities.ReduceField.KEY : Utilities.ReduceField.VALUE;
            return field + "." + colName;
        }

        /**
         * Returns, in ascending order, the positions of all bits that are set in at least one
         * of the given grouping-set masks. The caller uses these positions as indexes into the
         * GROUP BY key list.
         *
         * @param groupingSets grouping-set bit masks
         * @return a new list of bit positions; empty if no bit is set
         */
        private List<Integer> listGroupingSetKeyPositions(List<Long> groupingSets) {
            long unionMask = 0L;
            for (long groupingSet : groupingSets) {
                unionMask |= groupingSet;
            }
            return BitSet.valueOf(new long[] {unionMask}).stream().boxed().collect(Collectors.toList());
        }
    }

    /**
     * Walker context holding the configuration values read once per transformation.
     */
    private static class GroupingSetProcessorContext implements NodeProcessorCtx {
        /** Value of {@code HiveConf.ConfVars.BYTES_PER_REDUCER}. */
        public final long bytesPerReducer;
        /** Value of {@code HiveConf.ConfVars.MAX_REDUCERS}. */
        public final int maxReducers;
        /** Value of {@code HiveConf.ConfVars.HIVE_OPTIMIZE_GROUPING_SET_THRESHOLD}; GROUP BY
         *  operators whose estimated row count is below it are skipped. */
        public final long groupingSetThreshold;

        /**
         * @param hiveConf configuration from which the three settings are read
         */
        public GroupingSetProcessorContext(HiveConf hiveConf) {
            this.bytesPerReducer = hiveConf.getLongVar(HiveConf.ConfVars.BYTES_PER_REDUCER);
            this.maxReducers = hiveConf.getIntVar(HiveConf.ConfVars.MAX_REDUCERS);
            this.groupingSetThreshold = hiveConf.getLongVar(HiveConf.ConfVars.HIVE_OPTIMIZE_GROUPING_SET_THRESHOLD);
        }
    }
}

