/*
 * Decompiled with CFR 0.152.
 */
package org.apache.impala.calcite.coercenodes;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Util;
import org.apache.impala.calcite.coercenodes.CoerceOperandShuttle;
import org.apache.impala.calcite.functions.FunctionResolver;
import org.apache.impala.calcite.rel.node.ImpalaPlanRel;
import org.apache.impala.calcite.type.ImpalaTypeConverter;
import org.apache.impala.catalog.Function;
import org.apache.impala.catalog.Type;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Rewrites a Calcite {@link RelNode} tree so that expression operands, aggregate
 * call arguments, union inputs and literal rows carry the types produced by
 * {@link CoerceOperandShuttle}, {@link FunctionResolver} and
 * {@link ImpalaTypeConverter}.
 *
 * <p>The tree is processed bottom-up: the inputs of a node are rewritten first and
 * the node itself is then rebuilt only when either one of its inputs or one of its
 * own expressions changed. Unchanged subtrees are returned as the same instances.
 */
public class CoerceNodes {
    protected static final Logger LOG = LoggerFactory.getLogger(CoerceNodes.class.getName());

    /**
     * Entry point: coerces the tree rooted at {@code relNode}.
     *
     * @param relNode root of the tree to process
     * @param rexBuilder builder used to create casts, input references and types
     * @return the rewritten root, or {@code relNode} itself when nothing changed
     * @throws RuntimeException if the tree contains a node type this class does not handle
     */
    public static RelNode coerceNodes(RelNode relNode, RexBuilder rexBuilder) {
        RelNode newRelNode = coerceNodesInternal(relNode, rexBuilder);
        // Defensive fallback; the internal walk currently either returns a node or throws.
        return newRelNode != null ? newRelNode : relNode;
    }

    /**
     * Recursively coerces the inputs of {@code relNode}, then dispatches on the node
     * type reported by {@link ImpalaPlanRel#getRelNodeType} to rebuild the node itself.
     */
    private static RelNode coerceNodesInternal(RelNode relNode, RexBuilder rexBuilder) {
        boolean isInputChanged = false;
        List<RelNode> newInputs = new ArrayList<>();
        for (RelNode input : relNode.getInputs()) {
            RelNode changedInput = coerceNodesInternal(input, rexBuilder);
            // Identity comparison: an untouched subtree comes back as the same instance.
            isInputChanged |= changedInput != input;
            newInputs.add(changedInput);
        }
        switch (ImpalaPlanRel.getRelNodeType(relNode)) {
            case AGGREGATE: {
                return processAggNode(relNode, newInputs, rexBuilder, isInputChanged);
            }
            case FILTER: {
                return processFilterNode(relNode, newInputs, rexBuilder, isInputChanged);
            }
            case JOIN: {
                return processJoinNode(relNode, newInputs, rexBuilder, isInputChanged);
            }
            case PROJECT: {
                return processProjectNode(relNode, newInputs, rexBuilder, isInputChanged);
            }
            case SORT: {
                return processSortNode(relNode, newInputs, rexBuilder, isInputChanged);
            }
            case UNION: {
                return processUnionNode(relNode, newInputs, rexBuilder, isInputChanged);
            }
            case VALUES: {
                return processValuesNode(relNode, newInputs, rexBuilder, isInputChanged);
            }
            case HDFSSCAN: {
                // Scans are leaves with no expressions to coerce.
                return relNode;
            }
        }
        throw new RuntimeException("Unrecognized RelNode: " + relNode);
    }

    /**
     * Coerces the condition of a {@link LogicalFilter}.
     *
     * @return {@code relNode} when neither the condition nor the input changed,
     *     otherwise a copy over the new input with the (possibly rewritten) condition
     */
    private static RelNode processFilterNode(RelNode relNode, List<RelNode> inputs, RexBuilder rexBuilder, boolean isInputChanged) {
        LogicalFilter filter = (LogicalFilter) relNode;
        RexNode condition = filter.getCondition();
        List<RexNode> changedRexNodes = processRexNodes(relNode, inputs, ImmutableList.of(condition));
        if (changedRexNodes == null && !isInputChanged) {
            return relNode;
        }
        RexNode newCondition = changedRexNodes == null ? condition : changedRexNodes.get(0);
        return filter.copy(filter.getTraitSet(), inputs.get(0), newCondition);
    }

    /**
     * Coerces the condition of a {@link LogicalJoin}.
     *
     * @return {@code relNode} when neither the condition nor an input changed,
     *     otherwise a copy over the new inputs with the (possibly rewritten) condition
     */
    private static RelNode processJoinNode(RelNode relNode, List<RelNode> inputs, RexBuilder rexBuilder, boolean isInputChanged) {
        LogicalJoin join = (LogicalJoin) relNode;
        RexNode condition = join.getCondition();
        List<RexNode> changedRexNodes = processRexNodes(relNode, inputs, ImmutableList.of(condition));
        if (changedRexNodes == null && !isInputChanged) {
            return relNode;
        }
        RexNode newCondition = changedRexNodes == null ? condition : changedRexNodes.get(0);
        return join.copy(join.getTraitSet(), newCondition, inputs.get(0), inputs.get(1), join.getJoinType(), join.isSemiJoinDone());
    }

    /**
     * Coerces the expressions of a {@link LogicalProject} and recomputes its row type
     * from the resulting expression types, keeping the original field names.
     */
    private static RelNode processProjectNode(RelNode relNode, List<RelNode> inputs, RexBuilder rexBuilder, boolean isInputChanged) {
        LogicalProject project = (LogicalProject) relNode;
        List<RexNode> projects = project.getProjects();
        List<RexNode> changedRexNodes = processRexNodes(relNode, inputs, projects);
        if (changedRexNodes == null && !isInputChanged) {
            return relNode;
        }
        List<RexNode> newProjects = changedRexNodes == null ? projects : changedRexNodes;
        RelDataTypeFactory factory = rexBuilder.getTypeFactory();
        List<RelDataType> typeList = Util.transform(newProjects, RexNode::getType);
        RelDataType rowType = factory.createStructType(typeList, project.getRowType().getFieldNames());
        return project.copy(project.getTraitSet(), inputs.get(0), newProjects, rowType);
    }

    /** A sort has no expressions of its own to coerce; it is only re-parented when its input changed. */
    private static RelNode processSortNode(RelNode relNode, List<RelNode> inputs, RexBuilder rexBuilder, boolean isInputChanged) {
        return isInputChanged ? relNode.copy(relNode.getTraitSet(), inputs) : relNode;
    }

    /**
     * Coerces the arguments of the aggregate calls of a {@link LogicalAggregate}.
     *
     * <p>Calls accepted by {@link #matchesSignature} are kept as-is. For every other
     * call, each operand whose SQL type name differs from the resolved function's
     * argument type is cast in a new projection placed between the aggregate and its
     * input; the projection keeps all input fields and appends the cast fields, and
     * the call is re-pointed at the appended fields.
     */
    private static RelNode processAggNode(RelNode relNode, List<RelNode> inputs, RexBuilder rexBuilder, boolean isInputChanged) {
        LogicalAggregate agg = (LogicalAggregate) relNode;
        // The new projection is built on the coerced input, so field counts, names and
        // operand types are all taken from it.
        RelNode input = inputs.get(0);
        int numInputFields = input.getRowType().getFieldCount();
        List<AggregateCall> transformedAggCallList = new ArrayList<>();
        List<RexNode> newProjectFields = new ArrayList<>();
        List<String> newProjectFieldNames = new ArrayList<>();
        for (AggregateCall aggCall : agg.getAggCallList()) {
            List<RelDataType> operandTypes = getOperandTypes(input, aggCall);
            if (matchesSignature(aggCall, operandTypes)) {
                transformedAggCallList.add(aggCall);
                continue;
            }
            List<RelDataType> newOperandTypes = getCastedOperandTypes(aggCall, operandTypes);
            // Cast fields of this call start right after the fields appended so far.
            int firstNewFieldIndex = numInputFields + newProjectFields.size();
            transformedAggCallList.add(getNewAggCall(aggCall, operandTypes, newOperandTypes, firstNewFieldIndex));
            newProjectFields.addAll(getNewProjectFields(agg.getCluster().getRexBuilder(), aggCall, operandTypes, newOperandTypes));
            newProjectFieldNames.addAll(getNewProjectFieldNames(input, aggCall, operandTypes, newOperandTypes));
        }
        if (newProjectFields.isEmpty()) {
            return isInputChanged ? relNode.copy(relNode.getTraitSet(), inputs) : relNode;
        }
        RelNode project = createProject(input, newProjectFields, newProjectFieldNames);
        return LogicalAggregate.create(project, agg.getGroupSet(), agg.getGroupSets(), transformedAggCallList);
    }

    /**
     * Aligns the inputs of a {@link LogicalUnion} on a common row type by wrapping
     * inputs whose column types differ in a casting projection.
     */
    private static RelNode processUnionNode(RelNode relNode, List<RelNode> inputs, RexBuilder rexBuilder, boolean isInputChanged) {
        LogicalUnion union = (LogicalUnion) relNode;
        if (!isInputChanged && inputs.size() == 1) {
            return relNode;
        }
        List<RelDataType> commonRowType = getCompatibleRowType(relNode, inputs, rexBuilder);
        boolean inputsChanged = isInputChanged || haveTypesChanged(commonRowType, union.getRowType().getFieldList());
        List<RelNode> changedRelNodes = new ArrayList<>();
        for (RelNode input : inputs) {
            RelNode changedRelNode = getChangedUnionInput(input, commonRowType, rexBuilder);
            boolean inputChanged = !changedRelNode.equals(input);
            changedRelNodes.add(inputChanged ? changedRelNode : input);
            inputsChanged |= inputChanged;
        }
        return inputsChanged ? LogicalUnion.create(changedRelNodes, union.all) : relNode;
    }

    /**
     * Coerces the literal rows of a {@link LogicalValues}.
     *
     * <p>The values node itself is never rewritten. When coercion changes any
     * literal, the node is wrapped in a projection that casts each column to the
     * type compatible with all coerced rows.
     */
    private static RelNode processValuesNode(RelNode relNode, List<RelNode> inputs, RexBuilder rexBuilder, boolean isInputChanged) {
        LogicalValues values = (LogicalValues) relNode;
        if (values.getTuples().isEmpty()) {
            return relNode;
        }
        int nColumns = values.getRowType().getFieldList().size();
        // Fixed-size list of nulls; slot i is filled in once a coerced type for column i is seen.
        List<RelDataType> relDataTypes = Arrays.asList(new RelDataType[nColumns]);
        boolean needProject = false;
        for (List<RexLiteral> tuple : values.getTuples()) {
            List<RexNode> rexNodes = castToRexNodeList(tuple);
            List<RexNode> changedRexNodes = processRexNodes(relNode, inputs, rexNodes);
            if (changedRexNodes == null) {
                continue;
            }
            needProject = true;
            Preconditions.checkState(changedRexNodes.size() == relDataTypes.size());
            for (int i = 0; i < changedRexNodes.size(); ++i) {
                RexNode changed = changedRexNodes.get(i);
                if (changed == null) {
                    // Nothing known about this column from this row; leave its slot untouched.
                    continue;
                }
                Preconditions.checkState(changed.getKind() == SqlKind.CAST || changed instanceof RexLiteral);
                relDataTypes.set(i, getCompatibleDataType(relDataTypes.get(i), changed.getType(), rexBuilder));
            }
        }
        if (!needProject) {
            return relNode;
        }
        List<RexNode> projects = new ArrayList<>();
        for (int i = 0; i < relDataTypes.size(); ++i) {
            RexInputRef inputRef = rexBuilder.makeInputRef(values, i);
            RelDataType targetType = relDataTypes.get(i);
            // makeCast returns a general RexNode, so the projection list is typed as RexNode.
            projects.add(targetType != null ? rexBuilder.makeCast(targetType, inputRef) : inputRef);
        }
        return LogicalProject.create(values, new ArrayList<>(), projects, values.getRowType().getFieldNames());
    }

    /**
     * Runs every expression through a {@link CoerceOperandShuttle} followed by
     * {@link RexUtil#pullFactors} and {@link RexUtil#toCnf}.
     *
     * @return the full list of resulting expressions (in input order) when at least
     *     one result is a different instance than its source, otherwise {@code null}
     */
    private static List<RexNode> processRexNodes(RelNode relNode, List<RelNode> inputs, List<RexNode> rexNodes) {
        RexBuilder rexBuilder = relNode.getCluster().getRexBuilder();
        CoerceOperandShuttle shuttle = new CoerceOperandShuttle(relNode.getCluster().getTypeFactory(), rexBuilder, inputs);
        List<RexNode> changedRexNodes = new ArrayList<>();
        boolean rexNodeChanged = false;
        for (RexNode rexNode : rexNodes) {
            RexNode changedRexNode = shuttle.apply(rexNode);
            changedRexNode = RexUtil.pullFactors(rexBuilder, changedRexNode);
            changedRexNode = RexUtil.toCnf(rexBuilder, 100, changedRexNode);
            changedRexNodes.add(changedRexNode);
            rexNodeChanged |= changedRexNode != rexNode;
        }
        return rexNodeChanged ? changedRexNodes : null;
    }

    /** Returns a new mutable list holding the given literals, typed as {@link RexNode}. */
    private static List<RexNode> castToRexNodeList(List<RexLiteral> literalList) {
        return new ArrayList<>(literalList);
    }

    /**
     * Returns the type compatible with both arguments as computed by
     * {@link ImpalaTypeConverter}; a {@code null} argument means "no type yet" and
     * yields the other argument unchanged.
     */
    private static RelDataType getCompatibleDataType(RelDataType dt1, RelDataType dt2, RexBuilder rexBuilder) {
        if (dt1 == null) {
            return dt2;
        }
        if (dt2 == null) {
            return dt1;
        }
        return ImpalaTypeConverter.getCompatibleType(ImmutableList.of(dt1, dt2), rexBuilder.getTypeFactory());
    }

    /**
     * Computes, column by column, the type compatible with all inputs. Starts from
     * the field types of the first input and folds every further input into it.
     */
    private static List<RelDataType> getCompatibleRowType(RelNode origRelNode, List<RelNode> inputs, RexBuilder rexBuilder) {
        RelDataTypeFactory factory = rexBuilder.getTypeFactory();
        List<RelDataType> finalTypes = new ArrayList<>();
        for (RelDataTypeField field : inputs.get(0).getRowType().getFieldList()) {
            finalTypes.add(field.getType());
        }
        for (int j = 1; j < inputs.size(); ++j) {
            List<RelDataTypeField> fields = inputs.get(j).getRowType().getFieldList();
            for (int i = 0; i < fields.size(); ++i) {
                RelDataType type0 = finalTypes.get(i);
                RelDataType type1 = fields.get(i).getType();
                finalTypes.set(i, ImpalaTypeConverter.getCompatibleType(type0, type1, factory));
            }
        }
        return finalTypes;
    }

    /**
     * Wraps {@code relNode} in a projection that casts every column whose type
     * differs from {@code commonRowType} to the common type.
     *
     * @return {@code relNode} itself when all column types already match
     */
    private static RelNode getChangedUnionInput(RelNode relNode, List<RelDataType> commonRowType, RexBuilder rexBuilder) {
        boolean changed = false;
        List<RexNode> projects = new ArrayList<>();
        List<RelDataTypeField> fields = relNode.getRowType().getFieldList();
        for (int i = 0; i < fields.size(); ++i) {
            RexInputRef inputRef = rexBuilder.makeInputRef(relNode, i);
            RelDataType targetType = commonRowType.get(i);
            boolean projectChanged = !fields.get(i).getType().equals(targetType);
            // makeCast returns a general RexNode, so the projection list is typed as RexNode.
            projects.add(projectChanged ? rexBuilder.makeCast(targetType, inputRef) : inputRef);
            changed |= projectChanged;
        }
        if (!changed) {
            return relNode;
        }
        RelDataTypeFactory factory = rexBuilder.getTypeFactory();
        RelDataType rowType = factory.createStructType(commonRowType, relNode.getRowType().getFieldNames());
        return LogicalProject.create(relNode, new ArrayList<>(), projects, rowType);
    }

    /** Returns true when any type in {@code commonRowType} differs from the type of the field at the same position. */
    private static boolean haveTypesChanged(List<RelDataType> commonRowType, List<RelDataTypeField> fields) {
        Preconditions.checkState(commonRowType.size() == fields.size());
        for (int i = 0; i < commonRowType.size(); ++i) {
            if (!commonRowType.get(i).equals(fields.get(i).getType())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Decides whether an aggregate call can be kept without casting its operands.
     *
     * <p>{@code single_value} is always accepted. Any operand of SQL type
     * {@code NULL} forces coercion. Otherwise the call matches when
     * {@link FunctionResolver#getExactFunction} finds a function for the operand
     * types whose return type has the same SQL type name as the call.
     */
    private static boolean matchesSignature(AggregateCall aggCall, List<RelDataType> operandTypes) {
        String aggName = aggCall.getAggregation().getName();
        // equalsIgnoreCase does not depend on the default locale, unlike toLowerCase()
        // (e.g. "SINGLE_VALUE" lower-cases to a dotless i under a Turkish locale).
        if ("single_value".equalsIgnoreCase(aggName)) {
            return true;
        }
        for (RelDataType relDataType : operandTypes) {
            if (relDataType.getSqlTypeName() == SqlTypeName.NULL) {
                return false;
            }
        }
        Function fn = FunctionResolver.getExactFunction(aggName, operandTypes);
        if (fn == null) {
            return false;
        }
        RelDataType retType = ImpalaTypeConverter.getRelDataType(fn.getReturnType());
        return retType.getSqlTypeName().equals(aggCall.getType().getSqlTypeName());
    }

    /** Returns the types of the input fields referenced by the argument list of {@code aggCall}. */
    private static List<RelDataType> getOperandTypes(RelNode input, AggregateCall aggCall) {
        List<RelDataTypeField> fields = input.getRowType().getFieldList();
        List<RelDataType> operandTypes = new ArrayList<>();
        for (Integer argIndex : aggCall.getArgList()) {
            operandTypes.add(fields.get(argIndex).getType());
        }
        return operandTypes;
    }

    /**
     * Resolves the function to use for {@code aggCall} via
     * {@link FunctionResolver#getSupertypeFunction} and returns, per operand, the
     * argument type that function declares. Operands beyond the declared argument
     * count reuse the last declared argument type.
     *
     * @throws NullPointerException if no function can be resolved
     * @throws IllegalStateException if the resolved return type does not agree with the
     *     call's type (a call of SQL type {@code NULL} is always accepted)
     */
    private static List<RelDataType> getCastedOperandTypes(AggregateCall aggCall, List<RelDataType> operandTypes) {
        String aggName = aggCall.getAggregation().getName();
        Function fn = FunctionResolver.getSupertypeFunction(aggName, operandTypes);
        Preconditions.checkNotNull(fn, "Could not find matching functions for " + aggName);
        RelDataType retType = ImpalaTypeConverter.getRelDataType(fn.getReturnType());
        SqlTypeName callTypeName = aggCall.getType().getSqlTypeName();
        Preconditions.checkState(retType.getSqlTypeName().equals(callTypeName) || callTypeName.equals(SqlTypeName.NULL));
        Type[] fnArgs = fn.getArgs();
        List<RelDataType> newOperandTypes = new ArrayList<>();
        for (int i = 0; i < operandTypes.size(); ++i) {
            Type t = i < fnArgs.length ? fnArgs[i] : fnArgs[fnArgs.length - 1];
            newOperandTypes.add(ImpalaTypeConverter.getRelDataType(t));
        }
        return newOperandTypes;
    }

    /**
     * Builds a copy of {@code aggCall} whose arguments needing a cast point at the
     * cast fields appended to the projection, starting at index {@code numProjects}.
     * Arguments whose SQL type name is unchanged keep their original index.
     */
    private static AggregateCall getNewAggCall(AggregateCall aggCall, List<RelDataType> operandTypes, List<RelDataType> newOperandTypes, int numProjects) {
        List<Integer> argList = aggCall.getArgList();
        Preconditions.checkState(argList.size() == operandTypes.size());
        Preconditions.checkState(operandTypes.size() == newOperandTypes.size());
        List<Integer> newArgList = new ArrayList<>();
        for (int i = 0; i < operandTypes.size(); ++i) {
            boolean typesEqual = areSqlTypesEqual(operandTypes.get(i), newOperandTypes.get(i));
            int newArg = typesEqual ? argList.get(i) : numProjects++;
            newArgList.add(newArg);
        }
        return aggCall.withArgList(newArgList);
    }

    /**
     * Returns one cast expression per operand of {@code aggCall} whose SQL type name
     * changes, in operand order (the same order {@link #getNewAggCall} assigns indexes in).
     */
    private static List<RexNode> getNewProjectFields(RexBuilder rexBuilder, AggregateCall aggCall, List<RelDataType> operandTypes, List<RelDataType> newOperandTypes) {
        List<RexNode> newProjects = new ArrayList<>();
        for (int i = 0; i < operandTypes.size(); ++i) {
            if (areSqlTypesEqual(operandTypes.get(i), newOperandTypes.get(i))) {
                continue;
            }
            int argIndex = aggCall.getArgList().get(i);
            RexInputRef inputRef = rexBuilder.makeInputRef(operandTypes.get(i), argIndex);
            newProjects.add(rexBuilder.makeCast(newOperandTypes.get(i), inputRef));
        }
        return newProjects;
    }

    /**
     * Returns the names of the cast fields produced by {@link #getNewProjectFields},
     * each formed as {@code "cast_"} followed by the name of the field being cast.
     */
    private static List<String> getNewProjectFieldNames(RelNode input, AggregateCall aggCall, List<RelDataType> operandTypes, List<RelDataType> newOperandTypes) {
        List<String> inputFieldNames = input.getRowType().getFieldNames();
        List<String> newNames = new ArrayList<>();
        for (int i = 0; i < operandTypes.size(); ++i) {
            if (areSqlTypesEqual(operandTypes.get(i), newOperandTypes.get(i))) {
                continue;
            }
            int argIndex = aggCall.getArgList().get(i);
            newNames.add("cast_" + inputFieldNames.get(argIndex));
        }
        return newNames;
    }

    /** Compares only the SQL type names of the two types; precision, scale and nullability are ignored. */
    private static boolean areSqlTypesEqual(RelDataType r1, RelDataType r2) {
        return r1.getSqlTypeName().equals(r2.getSqlTypeName());
    }

    /**
     * Creates a projection over {@code input} that passes every input field through
     * unchanged and appends {@code newProjectFields} under {@code newFieldNames}.
     */
    private static RelNode createProject(RelNode input, List<RexNode> newProjectFields, List<String> newFieldNames) {
        List<RexNode> projects = new ArrayList<>();
        RexBuilder rexBuilder = input.getCluster().getRexBuilder();
        List<RelDataTypeField> inputFields = input.getRowType().getFieldList();
        for (int i = 0; i < input.getRowType().getFieldCount(); ++i) {
            projects.add(rexBuilder.makeInputRef(inputFields.get(i).getType(), i));
        }
        List<String> fieldNames = new ArrayList<>(input.getRowType().getFieldNames());
        projects.addAll(newProjectFields);
        fieldNames.addAll(newFieldNames);
        return RelFactories.DEFAULT_PROJECT_FACTORY.createProject(input, new ArrayList<>(), projects, fieldNames, new HashSet<>());
    }
}

