/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer;

import com.google.common.base.Preconditions;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.ColumnInfo;
import org.apache.hadoop.hive.ql.exec.FilterOperator;
import org.apache.hadoop.hive.ql.exec.FunctionUtils;
import org.apache.hadoop.hive.ql.exec.GroupByOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.OperatorFactory;
import org.apache.hadoop.hive.ql.exec.OperatorUtils;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.RowSchema;
import org.apache.hadoop.hive.ql.exec.SelectOperator;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.io.AcidUtils;
import org.apache.hadoop.hive.ql.optimizer.Transform;
import org.apache.hadoop.hive.ql.parse.GenTezUtils;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.RuntimeValuesInfo;
import org.apache.hadoop.hive.ql.parse.SemanticAnalyzer;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.parse.SemiJoinBranchInfo;
import org.apache.hadoop.hive.ql.plan.AggregationDesc;
import org.apache.hadoop.hive.ql.plan.DynamicValue;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDescUtils;
import org.apache.hadoop.hive.ql.plan.ExprNodeDynamicValueDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
import org.apache.hadoop.hive.ql.plan.FilterDesc;
import org.apache.hadoop.hive.ql.plan.GroupByDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.PlanUtils;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
import org.apache.hadoop.hive.ql.plan.SelectDesc;
import org.apache.hadoop.hive.ql.plan.TableDesc;
import org.apache.hadoop.hive.ql.plan.TableScanDesc;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFBloomFilter;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMin;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBetween;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFInBloomFilter;
import org.apache.hadoop.hive.ql.util.NullOrdering;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;

/**
 * Merges semijoin reduction branches that share the same source operator and the same target
 * table scan into a single branch.
 *
 * <p>Every group of two or more such branches is replaced by one branch that computes a
 * {@code min}/{@code max} pair per column plus a single bloom filter over a murmur hash of all
 * the columns. The target table scan's filter expression, and every {@link FilterOperator}
 * directly below the scan, are ANDed with the matching {@code between} and
 * {@code in_bloom_filter} predicates. The original branches are then removed from the plan.
 */
public class SemiJoinReductionMerge extends Transform {

    /**
     * Rewrites mergeable semijoin branches in {@code parseContext}.
     *
     * <p>For each (source, target) group with at least two branches this builds the chain
     * {@code SEL(cols..., hash) -> GBY(HASH) -> RS -> GBY(FINAL) -> RS}. It registers the final
     * reduce sink as a new semijoin branch together with its runtime values, attaches the
     * combined predicate to the target scan, and removes the original branches.
     *
     * @param parseContext parse context whose semijoin branches are rewritten in place
     * @return the same {@code parseContext}
     * @throws SemanticException propagated from reduce sink creation or branch removal
     */
    @Override
    public ParseContext transform(ParseContext parseContext) throws SemanticException {
        Map<ReduceSinkOperator, SemiJoinBranchInfo> allSemijoins = parseContext.getRsToSemiJoinBranchInfo();
        if (allSemijoins.isEmpty()) {
            return parseContext;
        }
        HiveConf hiveConf = parseContext.getConf();
        // createMergeCandidates groups into a freshly built map. The loop body may therefore add
        // and remove entries of the parse context's semijoin map without disturbing the iteration.
        for (Map.Entry<SJSourceTarget, List<ReduceSinkOperator>> sjMergeCandidate : createMergeCandidates(allSemijoins)) {
            List<ReduceSinkOperator> sjBranches = sjMergeCandidate.getValue();
            if (sjBranches.size() < 2) {
                // A lone branch has nothing to be merged with.
                continue;
            }
            SJSourceTarget sourceTarget = sjMergeCandidate.getKey();
            List<SelectOperator> selOps = new ArrayList<>(sjBranches.size());
            for (ReduceSinkOperator rs : sjBranches) {
                selOps.add(OperatorUtils.ancestor(rs, SelectOperator.class, 0, 0, 0, 0));
            }
            long sjEntriesHint = extractBloomEntriesHint(sjBranches);
            NullOrdering nullOrder = NullOrdering.defaultNullOrder(hiveConf);

            SelectOperator selectOp = mergeSelectOps(sourceTarget.source, selOps);
            GroupByOperator gbPartialOp = createGroupBy(selectOp, selectOp, GroupByDesc.Mode.HASH, sjEntriesHint, hiveConf);
            ReduceSinkOperator rsPartialOp = createReduceSink(gbPartialOp, nullOrder);
            rsPartialOp.getConf().setReducerTraits(EnumSet.of(ReduceSinkDesc.ReducerTraits.QUICKSTART));
            GroupByOperator gbCompleteOp = createGroupBy(selectOp, rsPartialOp, GroupByDesc.Mode.FINAL, sjEntriesHint, hiveConf);
            ReduceSinkOperator rsCompleteOp = createReduceSink(gbCompleteOp, nullOrder);

            TableScanOperator sjTargetTable = sourceTarget.target;
            SemiJoinBranchInfo sjInfo = new SemiJoinBranchInfo(sjTargetTable, false);
            parseContext.getRsToSemiJoinBranchInfo().put(rsCompleteOp, sjInfo);
            RuntimeValuesInfo valuesInfo = createRuntimeValuesInfo(rsCompleteOp, sjBranches, parseContext);
            parseContext.getRsToRuntimeValuesInfoMap().put(rsCompleteOp, valuesInfo);

            // Both helpers read the old branches' runtime values, so they must run before the
            // branches are removed below.
            ExprNodeGenericFuncDesc sjPredicate = createSemiJoinPredicate(sjBranches, valuesInfo, parseContext);
            addPredicateToTarget(sjTargetTable, sjPredicate);

            for (ReduceSinkOperator rs : sjBranches) {
                GenTezUtils.removeSemiJoinOperator(parseContext, rs, sjTargetTable);
                GenTezUtils.removeBranch(rs);
            }
        }
        return parseContext;
    }

    /**
     * ANDs {@code predicate} into every {@link FilterOperator} that is a direct child of
     * {@code target}, and into the table scan's own filter expression.
     */
    private static void addPredicateToTarget(TableScanOperator target, ExprNodeGenericFuncDesc predicate) {
        for (Operator<?> child : target.getChildOperators()) {
            if (child instanceof FilterOperator) {
                FilterDesc filter = ((FilterOperator) child).getConf();
                filter.setPredicate(ExprNodeDescUtils.and(filter.getPredicate(), predicate));
            }
        }
        TableScanDesc tsDesc = target.getConf();
        tsDesc.setFilterExpr(ExprNodeDescUtils.and(tsDesc.getFilterExpr(), predicate));
    }

    /**
     * Groups semijoin branches by (source operator, target table scan).
     *
     * <p>The source is the single parent of the {@link SelectOperator} found via
     * {@code OperatorUtils.ancestor(rs, SelectOperator.class, 0, 0, 0, 0)}. Insertion order of
     * {@code semijoins} is preserved in both the groups and the branches within each group.
     *
     * @param semijoins reduce sink to semijoin branch info mapping
     * @return the groups, backed by a map private to this call
     * @throws IllegalStateException if a branch's select operator does not have exactly one parent
     */
    private static Collection<Map.Entry<SJSourceTarget, List<ReduceSinkOperator>>> createMergeCandidates(Map<ReduceSinkOperator, SemiJoinBranchInfo> semijoins) {
        Map<SJSourceTarget, List<ReduceSinkOperator>> sjGroups = new LinkedHashMap<>();
        for (Map.Entry<ReduceSinkOperator, SemiJoinBranchInfo> smjEntry : semijoins.entrySet()) {
            ReduceSinkOperator rs = smjEntry.getKey();
            TableScanOperator ts = smjEntry.getValue().getTsOp();
            SelectOperator selOp = OperatorUtils.ancestor(rs, SelectOperator.class, 0, 0, 0, 0);
            // Validate before dereferencing so that a branch without a parent fails with this
            // message rather than an IndexOutOfBoundsException.
            Preconditions.checkState(selOp.getParentOperators().size() == 1, "Semijoin branches should not have multiple parents");
            Operator<?> source = selOp.getParentOperators().get(0);
            sjGroups.computeIfAbsent(new SJSourceTarget(source, ts), key -> new ArrayList<>()).add(rs);
        }
        return sjGroups.entrySet();
    }

    /**
     * Returns the largest expected-entries hint among the bloom filters of the given branches,
     * or {@code -1} when none of them carries a hint.
     *
     * @throws IllegalStateException if a branch's group-by does not hold exactly one bloom filter
     */
    private static long extractBloomEntriesHint(List<ReduceSinkOperator> sjBranches) {
        long bloomEntries = -1L;
        for (ReduceSinkOperator rs : sjBranches) {
            GroupByOperator gbOp = OperatorUtils.ancestor(rs, GroupByOperator.class, 0, 0, 0);
            List<GenericUDAFBloomFilter.GenericUDAFBloomFilterEvaluator> blooms = FunctionUtils.extractEvaluators(
                    gbOp.getConf().getAggregators(), GenericUDAFBloomFilter.GenericUDAFBloomFilterEvaluator.class);
            Preconditions.checkState(blooms.size() == 1);
            GenericUDAFBloomFilter.GenericUDAFBloomFilterEvaluator bloom = blooms.get(0);
            if (bloom.hasHintEntries()) {
                bloomEntries = Math.max(bloomEntries, bloom.getExpectedEntries());
            }
        }
        return bloomEntries;
    }

    /**
     * Builds the predicate for the target table of the merged branch.
     *
     * <p>The predicate is one {@code between(false, col, min, max)} per branch target column,
     * plus one {@code in_bloom_filter(murmur_hash(cols...), bloom)}, all ANDed together. Dynamic
     * value ids are consumed in order from {@code sjValueInfo}: a min/max pair per branch (in
     * {@code sjBranches} order), then the bloom filter id.
     *
     * @throws IllegalStateException if any branch targets more than one column
     */
    private static ExprNodeGenericFuncDesc createSemiJoinPredicate(List<ReduceSinkOperator> sjBranches, RuntimeValuesInfo sjValueInfo, ParseContext context) {
        ArrayDeque<String> dynamicIds = new ArrayDeque<>(sjValueInfo.getDynamicValueIDs());
        ArrayList<ExprNodeDesc> sjPredicates = new ArrayList<>();
        ArrayList<ExprNodeDesc> hashArgs = new ArrayList<>();
        for (ReduceSinkOperator rs : sjBranches) {
            RuntimeValuesInfo info = context.getRsToRuntimeValuesInfoMap().get(rs);
            Preconditions.checkState(info.getTargetColumns().size() == 1, "Cannot handle multi-column semijoin branches.");
            ExprNodeDesc targetColumn = info.getTargetColumns().get(0);
            TypeInfo typeInfo = targetColumn.getTypeInfo();
            DynamicValue minDynamic = new DynamicValue(dynamicIds.poll(), typeInfo);
            DynamicValue maxDynamic = new DynamicValue(dynamicIds.poll(), typeInfo);
            List<ExprNodeDesc> betweenArgs = Arrays.asList(
                    new ExprNodeConstantDesc(Boolean.FALSE),
                    targetColumn,
                    new ExprNodeDynamicValueDesc(minDynamic),
                    new ExprNodeDynamicValueDesc(maxDynamic));
            sjPredicates.add(new ExprNodeGenericFuncDesc(TypeInfoFactory.booleanTypeInfo, new GenericUDFBetween(), "between", betweenArgs));
            hashArgs.add(targetColumn);
        }
        ExprNodeGenericFuncDesc hashExp = ExprNodeDescUtils.murmurHash(hashArgs);
        assert (dynamicIds.size() == 1) : "There should be one column left untreated the one with the bloom filter";
        DynamicValue bloomDynamic = new DynamicValue(dynamicIds.poll(), TypeInfoFactory.binaryTypeInfo);
        List<ExprNodeDesc> bloomArgs = Arrays.asList(hashExp, new ExprNodeDynamicValueDesc(bloomDynamic));
        sjPredicates.add(new ExprNodeGenericFuncDesc(TypeInfoFactory.booleanTypeInfo, new GenericUDFInBloomFilter(), "in_bloom_filter", bloomArgs));
        return ExprNodeDescUtils.and(sjPredicates);
    }

    /**
     * Creates the runtime values info for the merged reduce sink {@code rs}.
     *
     * <p>Each value column of {@code rs} gets the dynamic value id {@code rs.toString()} followed
     * by the column's expression string. The target columns are the single target column of
     * each branch in {@code sjBranches}, in order.
     *
     * @throws IllegalStateException if any branch targets more than one column
     */
    private static RuntimeValuesInfo createRuntimeValuesInfo(ReduceSinkOperator rs, List<ReduceSinkOperator> sjBranches, ParseContext parseContext) {
        List<ExprNodeDesc> valueCols = rs.getConf().getValueCols();
        RuntimeValuesInfo info = new RuntimeValuesInfo();
        TableDesc rsFinalTableDesc = PlanUtils.getReduceValueTableDesc(PlanUtils.getFieldSchemasFromColumnList(valueCols, "_col"));
        ArrayList<String> dynamicValueIDs = new ArrayList<>(valueCols.size());
        for (ExprNodeDesc rsCol : valueCols) {
            dynamicValueIDs.add(rs.toString() + rsCol.getExprString());
        }
        info.setTableDesc(rsFinalTableDesc);
        info.setDynamicValueIDs(dynamicValueIDs);
        info.setColExprs(valueCols);
        ArrayList<ExprNodeDesc> targetTableExpressions = new ArrayList<>(sjBranches.size());
        for (ReduceSinkOperator sjBranch : sjBranches) {
            RuntimeValuesInfo sjInfo = parseContext.getRsToRuntimeValuesInfoMap().get(sjBranch);
            Preconditions.checkState(sjInfo.getTargetColumns().size() == 1, "Cannot handle multi-column semijoin branches.");
            targetTableExpressions.add(sjInfo.getTargetColumns().get(0));
        }
        info.setTargetColumns(targetTableExpressions);
        return info;
    }

    /**
     * Creates a {@link SelectOperator} under {@code parent}.
     *
     * <p>The operator projects the single column of each given select operator, followed by a
     * murmur hash of all those columns.
     *
     * @throws IllegalStateException if any select operator projects more than one column
     */
    private static SelectOperator mergeSelectOps(Operator<?> parent, List<SelectOperator> selectOperators) {
        ArrayList<String> colNames = new ArrayList<>();
        ArrayList<ExprNodeDesc> colDescs = new ArrayList<>();
        ArrayList<ColumnInfo> columnInfos = new ArrayList<>();
        HashMap<String, ExprNodeDesc> selectColumnExprMap = new HashMap<>();
        for (SelectOperator sel : selectOperators) {
            Preconditions.checkState(sel.getConf().getColList().size() == 1);
            ExprNodeDesc col = sel.getConf().getColList().get(0);
            String colName = HiveConf.getColumnInternalName(colDescs.size());
            colNames.add(colName);
            columnInfos.add(new ColumnInfo(colName, col.getTypeInfo(), "", false));
            colDescs.add(col);
            selectColumnExprMap.put(colName, col);
        }
        ExprNodeGenericFuncDesc hashExp = ExprNodeDescUtils.murmurHash(colDescs);
        // NOTE(review): the hash column uses index size() + 1, leaving index size() unused. Within
        // this class the name is only referenced through the RowSchema built below, so it is
        // kept as-is. The hash column is also not added to selectColumnExprMap; confirm whether
        // that is intended.
        String hashName = HiveConf.getColumnInternalName(colDescs.size() + 1);
        colNames.add(hashName);
        columnInfos.add(new ColumnInfo(hashName, hashExp.getTypeInfo(), "", false));
        ArrayList<ExprNodeDesc> selDescs = new ArrayList<>(colDescs);
        selDescs.add(hashExp);
        SelectDesc select = new SelectDesc(selDescs, colNames);
        SelectOperator selectOp = (SelectOperator) OperatorFactory.getAndMakeChild(select, new RowSchema(columnInfos), parent, new Operator[0]);
        selectOp.setColumnExprMap(selectColumnExprMap);
        return selectOp;
    }

    /**
     * Creates a {@link ReduceSinkOperator} under {@code parentOp} with no key columns.
     *
     * <p>Every column of the parent schema is forwarded as a value column, with output names
     * generated by {@link SemanticAnalyzer#getColumnInternalName(int)}.
     */
    private static ReduceSinkOperator createReduceSink(Operator<?> parentOp, NullOrdering nullOrder) throws SemanticException {
        RowSchema parentSchema = parentOp.getSchema();
        List<ColumnInfo> signature = parentSchema.getSignature();
        ArrayList<ExprNodeDesc> valueCols = new ArrayList<>(signature.size());
        ArrayList<String> outColNames = new ArrayList<>(signature.size());
        for (int i = 0; i < signature.size(); ++i) {
            ColumnInfo colInfo = signature.get(i);
            valueCols.add(new ExprNodeColumnDesc(colInfo.getType(), colInfo.getInternalName(), "", false));
            outColNames.add(SemanticAnalyzer.getColumnInternalName(i));
        }
        ReduceSinkDesc rsDesc = PlanUtils.getReduceSinkDesc(Collections.emptyList(), valueCols, outColNames, false, -1, 0, 1, AcidUtils.Operation.NOT_ACID, nullOrder);
        rsDesc.setColumnExprMap(Collections.emptyMap());
        return (ReduceSinkOperator) OperatorFactory.getAndMakeChild(rsDesc, new RowSchema(parentSchema), parentOp, new Operator[0]);
    }

    /**
     * Creates a key-less {@link GroupByOperator} under {@code parentOp}.
     *
     * <p>The operator computes a {@code min}/{@code max} pair for every input column and one
     * bloom filter over the last (hash) column. In {@code HASH} mode the inputs are the columns of
     * {@code selectOp}; in {@code FINAL} mode they are the {@code VALUE} columns of the parent
     * reduce sink. The last output column is typed binary; the others keep their input types.
     *
     * @throws AssertionError for any mode other than {@code HASH} or {@code FINAL}
     */
    private static GroupByOperator createGroupBy(SelectOperator selectOp, Operator<?> parentOp, GroupByDesc.Mode gbMode, long bloomEntriesHint, HiveConf hiveConf) {
        List<ExprNodeDesc> params;
        GenericUDAFEvaluator.Mode udafMode = SemanticAnalyzer.groupByDescModeToUDAFMode(gbMode, false);
        switch (gbMode) {
            case FINAL:
                params = createGroupByAggregationParameters((ReduceSinkOperator) parentOp);
                break;
            case HASH:
                params = createGroupByAggregationParameters(selectOp);
                break;
            default:
                throw new AssertionError(gbMode + " is not supported");
        }
        ArrayList<AggregationDesc> gbAggs = new ArrayList<>();
        ArrayDeque<ExprNodeDesc> paramsCopy = new ArrayDeque<>(params);
        // params is laid out as [min0, max0, min1, max1, ..., bloom].
        while (paramsCopy.size() > 1) {
            gbAggs.add(minAggregation(udafMode, paramsCopy.poll()));
            gbAggs.add(maxAggregation(udafMode, paramsCopy.poll()));
        }
        gbAggs.add(bloomFilterAggregation(udafMode, paramsCopy.poll(), selectOp, bloomEntriesHint, hiveConf));
        assert (paramsCopy.size() == 0);
        ArrayList<String> gbOutputNames = new ArrayList<>(gbAggs.size());
        ArrayList<ColumnInfo> gbColInfos = new ArrayList<>(gbAggs.size());
        for (int i = 0; i < params.size(); ++i) {
            String colName = HiveConf.getColumnInternalName(i);
            gbOutputNames.add(colName);
            TypeInfo colType = i == params.size() - 1 ? TypeInfoFactory.binaryTypeInfo : params.get(i).getTypeInfo();
            gbColInfos.add(new ColumnInfo(colName, colType, "", false));
        }
        float groupByMemoryUsage = hiveConf.getFloatVar(HiveConf.ConfVars.HIVEMAPAGGRHASHMEMORY);
        float memoryThreshold = hiveConf.getFloatVar(HiveConf.ConfVars.HIVEMAPAGGRMEMORYTHRESHOLD);
        float minReductionHashAggr = hiveConf.getFloatVar(HiveConf.ConfVars.HIVEMAPAGGRHASHMINREDUCTION);
        float minReductionHashAggrLowerBound = hiveConf.getFloatVar(HiveConf.ConfVars.HIVEMAPAGGRHASHMINREDUCTIONLOWERBOUND);
        GroupByDesc groupBy = new GroupByDesc(gbMode, gbOutputNames, Collections.emptyList(), gbAggs, false, groupByMemoryUsage, memoryThreshold, minReductionHashAggr, minReductionHashAggrLowerBound, null, false, -1, false);
        groupBy.setColumnExprMap(Collections.emptyMap());
        return (GroupByOperator) OperatorFactory.getAndMakeChild(groupBy, new RowSchema(gbColInfos), parentOp, new Operator[0]);
    }

    /**
     * Builds HASH-mode aggregation inputs from {@code selectOp}'s schema.
     *
     * <p>Each column reference is repeated twice (one for min, one for max), except the last
     * (hash) column, which appears once. For example, {@code [c0, c1, h]} becomes
     * {@code [c0, c0, c1, c1, h]}. The two entries of a pair share the same expression instance.
     */
    private static List<ExprNodeDesc> createGroupByAggregationParameters(SelectOperator selectOp) {
        ArrayList<ExprNodeDesc> params = new ArrayList<>();
        for (ColumnInfo c : selectOp.getSchema().getSignature()) {
            ExprNodeColumnDesc p = new ExprNodeColumnDesc(new ColumnInfo(c.getInternalName(), c.getType(), "", false));
            params.add(p);
            params.add(p);
        }
        // Drop the duplicate of the hash column: it only feeds the bloom filter.
        params.remove(params.size() - 1);
        return params;
    }

    /**
     * Builds FINAL-mode aggregation inputs.
     *
     * <p>Returns one {@code VALUE.<name>} column reference per column of {@code reduceOp}'s
     * schema, in schema order.
     */
    private static List<ExprNodeDesc> createGroupByAggregationParameters(ReduceSinkOperator reduceOp) {
        ArrayList<ExprNodeDesc> params = new ArrayList<>();
        for (ColumnInfo c : reduceOp.getSchema().getSignature()) {
            String name = Utilities.ReduceField.VALUE + "." + c.getInternalName();
            params.add(new ExprNodeColumnDesc(new ColumnInfo(name, c.getType(), "", false)));
        }
        return params;
    }

    /** Creates a {@code min} aggregation over {@code col} in the given UDAF mode. */
    private static AggregationDesc minAggregation(GenericUDAFEvaluator.Mode mode, ExprNodeDesc col) {
        List<ExprNodeDesc> p = Collections.singletonList(col);
        return new AggregationDesc("min", new GenericUDAFMin.GenericUDAFMinEvaluator(), p, false, mode);
    }

    /** Creates a {@code max} aggregation over {@code col} in the given UDAF mode. */
    private static AggregationDesc maxAggregation(GenericUDAFEvaluator.Mode mode, ExprNodeDesc col) {
        List<ExprNodeDesc> p = Collections.singletonList(col);
        return new AggregationDesc("max", new GenericUDAFMax.GenericUDAFMaxEvaluator(), p, false, mode);
    }

    /**
     * Creates a {@code bloom_filter} aggregation over {@code col}.
     *
     * <p>The evaluator is sized from {@code TEZ_MAX_BLOOM_FILTER_ENTRIES},
     * {@code TEZ_MIN_BLOOM_FILTER_ENTRIES} and {@code TEZ_BLOOM_FILTER_FACTOR}, plus the given
     * entries hint. Its source operator is set to {@code source}.
     */
    private static AggregationDesc bloomFilterAggregation(GenericUDAFEvaluator.Mode mode, ExprNodeDesc col, SelectOperator source, long numEntriesHint, HiveConf conf) {
        GenericUDAFBloomFilter.GenericUDAFBloomFilterEvaluator bloomFilterEval = new GenericUDAFBloomFilter.GenericUDAFBloomFilterEvaluator();
        bloomFilterEval.setSourceOperator(source);
        bloomFilterEval.setMaxEntries(conf.getLongVar(HiveConf.ConfVars.TEZ_MAX_BLOOM_FILTER_ENTRIES));
        bloomFilterEval.setMinEntries(conf.getLongVar(HiveConf.ConfVars.TEZ_MIN_BLOOM_FILTER_ENTRIES));
        bloomFilterEval.setFactor(conf.getFloatVar(HiveConf.ConfVars.TEZ_BLOOM_FILTER_FACTOR));
        bloomFilterEval.setHintEntries(numEntriesHint);
        List<ExprNodeDesc> p = Collections.singletonList(col);
        AggregationDesc bloom = new AggregationDesc("bloom_filter", bloomFilterEval, p, false, mode);
        bloom.setGenericUDAFWritableEvaluator(bloomFilterEval);
        return bloom;
    }

    /**
     * Grouping key for merge candidates: the operator feeding a semijoin branch, together with
     * the table scan the branch targets. Equality delegates to both components' {@code equals}.
     */
    private static final class SJSourceTarget {
        private final Operator<?> source;
        private final TableScanOperator target;

        public SJSourceTarget(Operator<?> source, TableScanOperator target) {
            this.source = source;
            this.target = target;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            SJSourceTarget that = (SJSourceTarget) o;
            return this.source.equals(that.source) && this.target.equals(that.target);
        }

        @Override
        public int hashCode() {
            return 31 * this.source.hashCode() + this.target.hashCode();
        }
    }
}

