/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
import org.apache.calcite.util.Pair;
import org.apache.commons.collections4.multimap.ArrayListValuedHashMap;
import org.apache.hadoop.hive.ql.exec.ColumnInfo;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.OperatorFactory;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.RowSchema;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.optimizer.Transform;
import org.apache.hadoop.hive.ql.optimizer.graph.OperatorGraph;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
import org.apache.hadoop.hive.ql.plan.SelectDesc;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ParallelEdgeFixer
extends Transform {
    protected static final Logger LOG = LoggerFactory.getLogger(ParallelEdgeFixer.class);

    @Override
    public ParseContext transform(ParseContext pctx) throws SemanticException {
        OperatorGraph og = new OperatorGraph(pctx);
        this.fixParallelEdges(og);
        return pctx;
    }

    private void fixParallelEdges(OperatorGraph og) throws SemanticException {
        ArrayListValuedHashMap edgeOperators = new ArrayListValuedHashMap();
        for (OperatorGraph.Cluster c : og.getClusters()) {
            for (Operator<?> o : c.getMembers()) {
                for (Operator<OperatorDesc> p : o.getParentOperators()) {
                    OperatorGraph.Cluster parentCluster = og.clusterOf(p);
                    if (parentCluster == c) continue;
                    edgeOperators.put((Object)new Pair((Object)parentCluster, (Object)c), (Object)new Pair(p, o));
                }
            }
        }
        for (Pair key : edgeOperators.keySet()) {
            List values = edgeOperators.get((Object)key);
            if (values.size() <= 1) continue;
            values.sort(new OperatorPairComparator());
            this.removeOneEdge(values);
            for (Pair pair : values) {
                this.fixParallelEdge((Operator)pair.left, (Operator)pair.right);
            }
        }
    }

    private void removeOneEdge(List<Pair<Operator<?>, Operator<?>>> values) {
        Pair<Operator<?>, Operator<?>> toKeep = null;
        for (Pair<Operator<?>, Operator<?>> pair : values) {
            if (this.isParallelEdgeSupported(pair)) continue;
            if (toKeep != null) {
                throw new RuntimeException("More than one operators which may not reshuffled!");
            }
            toKeep = pair;
        }
        if (toKeep == null) {
            toKeep = values.get(values.size() - 1);
        }
        values.remove(toKeep);
    }

    public boolean isParallelEdgeSupported(Pair<Operator<?>, Operator<?>> pair) {
        Operator rs = (Operator)pair.left;
        if (rs instanceof ReduceSinkOperator && !ParallelEdgeFixer.colMappingInverseKeys((ReduceSinkOperator)rs).isPresent()) {
            return false;
        }
        Operator child = (Operator)pair.right;
        if (child instanceof MapJoinOperator) {
            return true;
        }
        return child instanceof TableScanOperator;
    }

    private void fixParallelEdge(Operator<? extends OperatorDesc> p, Operator<?> o) throws SemanticException {
        LOG.info("Fixing parallel by adding a concentrator RS between {} -> {}", p, o);
        ReduceSinkDesc conf = (ReduceSinkDesc)p.getConf();
        ReduceSinkDesc newConf = (ReduceSinkDesc)conf.clone();
        Operator<SelectDesc> newSEL = this.buildSEL(p, conf);
        Operator<ReduceSinkDesc> newRS = OperatorFactory.getAndMakeChild(p.getCompilationOpContext(), newConf, new ArrayList<Operator<? extends OperatorDesc>>());
        conf.setOutputName("forward_to_" + newRS);
        conf.setTag(0);
        newConf.setKeyCols(new ArrayList<ExprNodeDesc>(conf.getKeyCols()));
        newRS.setSchema(new RowSchema(p.getSchema()));
        p.replaceChild(o, newSEL);
        newSEL.setParentOperators(Lists.newArrayList((Object[])new Operator[]{p}));
        newSEL.setChildOperators(Lists.newArrayList((Object[])new Operator[]{newRS}));
        newRS.setParentOperators(Lists.newArrayList((Object[])new Operator[]{newSEL}));
        newRS.setChildOperators(Lists.newArrayList((Object[])new Operator[]{o}));
        o.replaceParent(p, newRS);
    }

    private Operator<SelectDesc> buildSEL(Operator<? extends OperatorDesc> p, ReduceSinkDesc conf) throws SemanticException {
        ArrayList<ExprNodeDesc> colList = new ArrayList<ExprNodeDesc>();
        ArrayList<String> outputColumnNames = new ArrayList<String>();
        ArrayList<ColumnInfo> newColumns = new ArrayList<ColumnInfo>();
        Set<String> inverseKeys = ParallelEdgeFixer.colMappingInverseKeys((ReduceSinkOperator)p).get();
        for (String colName : inverseKeys) {
            ExprNodeDesc expr = conf.getColumnExprMap().get(colName);
            ExprNodeColumnDesc colRef = new ExprNodeColumnDesc(expr.getTypeInfo(), colName, colName, false);
            colList.add(colRef);
            String newColName = ParallelEdgeFixer.extractColumnName(expr);
            outputColumnNames.add(newColName);
            ColumnInfo newColInfo = new ColumnInfo(p.getSchema().getColumnInfo(colName));
            newColInfo.setInternalName(newColName);
            newColumns.add(newColInfo);
        }
        SelectDesc selConf = new SelectDesc(colList, outputColumnNames);
        Operator<SelectDesc> newSEL = OperatorFactory.getAndMakeChild(p.getCompilationOpContext(), selConf, new ArrayList<Operator<? extends OperatorDesc>>());
        newSEL.setSchema(new RowSchema(newColumns));
        return newSEL;
    }

    private static String extractColumnName(ExprNodeDesc expr) throws SemanticException {
        if (expr instanceof ExprNodeColumnDesc) {
            ExprNodeColumnDesc exprNodeColumnDesc = (ExprNodeColumnDesc)expr;
            return exprNodeColumnDesc.getColumn();
        }
        if (expr instanceof ExprNodeConstantDesc) {
            ExprNodeConstantDesc exprNodeConstantDesc = (ExprNodeConstantDesc)expr;
            return exprNodeConstantDesc.getFoldedFromCol();
        }
        throw new SemanticException("unexpected mapping expression!");
    }

    public static Optional<Set<String>> colMappingInverseKeys(ReduceSinkOperator rs) {
        HashMap<String, String> ret = new HashMap<String, String>();
        Map<String, ExprNodeDesc> exprMap = rs.getColumnExprMap();
        try {
            for (Map.Entry<String, ExprNodeDesc> e : exprMap.entrySet()) {
                ret.put(ParallelEdgeFixer.extractColumnName(e.getValue()), e.getKey());
            }
            return Optional.of(new TreeSet(ret.values()));
        }
        catch (SemanticException e) {
            return Optional.empty();
        }
    }

    private static class OperatorPairComparator
    implements Comparator<Pair<Operator<?>, Operator<?>>> {
        private OperatorPairComparator() {
        }

        @Override
        public int compare(Pair<Operator<?>, Operator<?>> o1, Pair<Operator<?>, Operator<?>> o2) {
            return this.sig(o1).compareTo(this.sig(o2));
        }

        private String sig(Pair<Operator<?>, Operator<?>> o1) {
            return ((Operator)o1.left).toString() + ((Operator)o1.right).toString();
        }
    }
}

