/*
 * Decompiled with CFR 0.152.
 */
package org.apache.impala.calcite.coercenodes;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexOver;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlDatetimePlusOperator;
import org.apache.calcite.sql.fun.SqlDatetimeSubtractionOperator;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.util.Util;
import org.apache.impala.calcite.functions.FunctionResolver;
import org.apache.impala.calcite.functions.ImplicitTypeChecker;
import org.apache.impala.calcite.operators.ImpalaDecodeFunction;
import org.apache.impala.calcite.type.ImpalaTypeConverter;
import org.apache.impala.catalog.Function;
import org.apache.impala.catalog.ScalarType;
import org.apache.impala.catalog.Type;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link RexShuttle} that rewrites an expression tree so that the operands of each call
 * match the argument types of the Impala function resolved for that call.
 *
 * <p>For every call that needs coercion, the matching Impala function is obtained from
 * {@link FunctionResolver#getSupertypeFunction}. Operands whose types differ from the
 * function's argument types are wrapped in casts, permitted only when
 * {@link ImplicitTypeChecker#supportsImplicitCasting} allows them. The call is rebuilt with
 * the function's return type. Decimal return types keep the type already derived on the
 * call.
 *
 * <p>The shuttle also makes these leaf-level adjustments:
 * <ul>
 *   <li>CHAR literals are rebuilt as STRING literals.</li>
 *   <li>INTEGER literals are retyped through {@link ImpalaTypeConverter#getLiteralDataType}.</li>
 *   <li>Input references are retyped to the matching field type of {@code inputs}.</li>
 * </ul>
 */
public class CoerceOperandShuttle
extends RexShuttle {
    protected static final Logger LOG = LoggerFactory.getLogger(CoerceOperandShuttle.class.getName());

    /** Operator kinds whose operands are never coerced (see {@link #isCastingNeeded}). */
    public static final Set<SqlKind> NO_CAST_OPERATORS =
        ImmutableSet.of(SqlKind.CAST, SqlKind.OR, SqlKind.AND);

    private final RelDataTypeFactory factory;
    private final RexBuilder rexBuilder;
    private final List<RelNode> inputs;

    /**
     * @param factory type factory used to build decimal cast targets
     * @param rexBuilder builder used to create casts, calls, literals and input refs
     * @param inputs the input(s) whose row type(s) {@link RexInputRef} indexes refer to;
     *     either one input, or two (left then right) whose fields are indexed consecutively
     */
    public CoerceOperandShuttle(RelDataTypeFactory factory, RexBuilder rexBuilder, List<RelNode> inputs) {
        this.factory = factory;
        this.rexBuilder = rexBuilder;
        this.inputs = inputs;
    }

    /**
     * Coerces a call's operands (after visiting them) to the resolved Impala signature.
     *
     * <p>SEARCH calls are first expanded with {@link RexUtil#expandSearch} and the expansion
     * is visited instead. The visited call is passed through {@link RexUtil#flatten}. Calls
     * for which {@link #isCastingNeeded} is false are returned unchanged. DECODE calls are
     * handled by {@link #castDecodedFunction}.
     *
     * @throws RuntimeException if no Impala function matches the call
     */
    @Override
    public RexNode visitCall(RexCall call) {
        if (call.getOperator().getKind() == SqlKind.SEARCH) {
            return this.visitCall((RexCall) RexUtil.expandSearch(this.rexBuilder, null, call));
        }
        RexCall castedOperandsCall =
            (RexCall) RexUtil.flatten(this.rexBuilder, (RexCall) super.visitCall(call));
        if (!this.isCastingNeeded(castedOperandsCall)) {
            return castedOperandsCall;
        }
        if (castedOperandsCall.getOperator().getName().equals("DECODE")) {
            return CoerceOperandShuttle.castDecodedFunction(castedOperandsCall, this.factory, this.rexBuilder);
        }
        Function fn = FunctionResolver.getSupertypeFunction(castedOperandsCall);
        if (fn == null) {
            throw new RuntimeException("Could not find a matching signature for call " + call);
        }
        // getReturnType already validates decimal results and keeps the call's own decimal type.
        RelDataType retType = this.getReturnType(castedOperandsCall, fn.getReturnType());
        List<RexNode> newOperands = CoerceOperandShuttle.getCastedArgTypes(
            fn, castedOperandsCall.getOperands(), retType, this.factory, this.rexBuilder);
        if (retType.equals(castedOperandsCall.getType())
                && newOperands.equals(castedOperandsCall.getOperands())) {
            return castedOperandsCall;
        }
        return (RexCall) this.rexBuilder.makeCall(retType, castedOperandsCall.getOperator(), newOperands);
    }

    /**
     * Coerces the operands of a windowed aggregate to the resolved Impala signature, keeping
     * the original window definition (partition/order keys, bounds, ROWS flag), distinctness
     * and null handling.
     *
     * @throws RuntimeException if no Impala function matches the aggregate
     */
    @Override
    public RexNode visitOver(RexOver over) {
        RexOver castedOver = (RexOver) super.visitOver(over);
        Function fn = FunctionResolver.getSupertypeFunction((RexCall) castedOver);
        if (fn == null) {
            throw new RuntimeException("Could not find a matching signature for call " + over);
        }
        RelDataType retType = this.getReturnType(castedOver, fn.getReturnType());
        List<RexNode> newOperands = CoerceOperandShuttle.getCastedArgTypes(
            fn, castedOver.getOperands(), retType, this.factory, this.rexBuilder);
        if (retType.equals(castedOver.getType()) && newOperands.equals(castedOver.getOperands())) {
            return castedOver;
        }
        // allowPartial = true, nullWhenCountZero = false (same flags as before).
        // NOTE(review): nullWhenCountZero = false presumably keeps the result a bare RexOver,
        // as required by the cast below; confirm.
        return (RexOver) this.rexBuilder.makeOver(retType, castedOver.getAggOperator(), newOperands,
            castedOver.getWindow().partitionKeys, castedOver.getWindow().orderKeys,
            castedOver.getWindow().getLowerBound(), castedOver.getWindow().getUpperBound(),
            castedOver.getWindow().isRows(), true, false, castedOver.isDistinct(),
            castedOver.ignoreNulls());
    }

    /**
     * Rebuilds CHAR literals as STRING literals and INTEGER literals with the type chosen by
     * {@link ImpalaTypeConverter#getLiteralDataType}. All other literals are returned as is.
     */
    @Override
    public RexNode visitLiteral(RexLiteral literal) {
        SqlTypeName typeName = literal.getType().getSqlTypeName();
        if (typeName == SqlTypeName.CHAR) {
            return this.rexBuilder.makeLiteral((Object) RexLiteral.stringValue(literal),
                ImpalaTypeConverter.getRelDataType((Type) Type.STRING), true, true);
        }
        if (typeName == SqlTypeName.INTEGER) {
            BigDecimal value = literal.getValueAs(BigDecimal.class);
            RelDataType type = ImpalaTypeConverter.getLiteralDataType(value, literal.getType());
            return this.rexBuilder.makeLiteral((Object) value, type);
        }
        return literal;
    }

    /** Retypes an input reference to the field type it points at in {@code inputs}. */
    @Override
    public RexNode visitInputRef(RexInputRef inputRef) {
        RelDataType inputRefIndexType = this.getInputRefIndexType(this.inputs, inputRef.getIndex());
        return inputRef.getType().equals(inputRefIndexType)
            ? inputRef
            : this.rexBuilder.makeInputRef(inputRefIndexType, inputRef.getIndex());
    }

    /**
     * Converts the Impala return type to a Calcite type. For decimal results the expression's
     * existing (decimal) type is kept instead.
     *
     * @throws IllegalStateException if the Impala type is decimal but the expression's is not
     */
    private RelDataType getReturnType(RexNode rexNode, Type impalaReturnType) {
        RelDataType retType = ImpalaTypeConverter.getRelDataType(impalaReturnType);
        if (!SqlTypeUtil.isDecimal(retType)) {
            return retType;
        }
        Preconditions.checkState(SqlTypeUtil.isDecimal(rexNode.getType()),
            "Decimal return type resolved for non-decimal expression %s", rexNode);
        return rexNode.getType();
    }

    /**
     * Returns false for calls that must not be coerced: NO_CAST_OPERATORS kinds, timestamp
     * arithmetic, and calls to the operator named "EXPLICIT_CAST".
     */
    private boolean isCastingNeeded(RexCall rexCall) {
        if (NO_CAST_OPERATORS.contains(rexCall.getOperator().getKind())) {
            return false;
        }
        if (CoerceOperandShuttle.isTimestampArithExpr(rexCall)) {
            return false;
        }
        return !rexCall.getOperator().getName().equals("EXPLICIT_CAST");
    }

    /**
     * Detects datetime arithmetic. This is true for any of:
     * <ul>
     *   <li>a Calcite datetime plus/minus operator;</li>
     *   <li>a call with an interval result;</li>
     *   <li>a binary operator named "+" or "-" with at least one datetime operand.</li>
     * </ul>
     */
    private static boolean isTimestampArithExpr(RexCall rexCall) {
        if (rexCall.getOperator() instanceof SqlDatetimePlusOperator
                || rexCall.getOperator() instanceof SqlDatetimeSubtractionOperator) {
            return true;
        }
        if (SqlTypeName.INTERVAL_TYPES.contains(rexCall.getType().getSqlTypeName())) {
            return true;
        }
        // Compare the operator NAME: comparing the SqlOperator itself to a String is never true.
        String opName = rexCall.getOperator().getName();
        if (!("+".equals(opName) || "-".equals(opName))) {
            return false;
        }
        List<RexNode> operands = rexCall.getOperands();
        // Guard against unary "-" so get(1) cannot go out of bounds.
        if (operands.size() != 2) {
            return false;
        }
        return SqlTypeUtil.isDatetime(operands.get(0).getType())
            || SqlTypeUtil.isDatetime(operands.get(1).getType());
    }

    /**
     * Looks up the type of field {@code index}. With a single input, the index is into that
     * input's row type. With two inputs, indexes past the left row type continue into the
     * right row type.
     *
     * @throws IllegalStateException if there are neither one nor two inputs
     */
    private RelDataType getInputRefIndexType(List<RelNode> inputs, int index) {
        if (inputs.size() == 1) {
            return inputs.get(0).getRowType().getFieldList().get(index).getType();
        }
        Preconditions.checkState(inputs.size() == 2, "Expected 1 or 2 inputs but got %s", inputs.size());
        List<RelDataTypeField> leftFieldList = inputs.get(0).getRowType().getFieldList();
        if (index < leftFieldList.size()) {
            return leftFieldList.get(index).getType();
        }
        int rightIndex = index - leftFieldList.size();
        return inputs.get(1).getRowType().getFieldList().get(rightIndex).getType();
    }

    /**
     * Casts each operand to the type expected by {@code fn}.
     *
     * <p>Operands past the function's declared argument count use the last declared argument
     * type. For CASE, operands selected by {@link FunctionResolver#shouldSkipOperandForCase}
     * are left alone.
     *
     * @return {@code operands} itself if nothing was cast, otherwise a new list
     * @throws NullPointerException if an operand cannot be implicitly cast
     */
    private static List<RexNode> getCastedArgTypes(Function fn, List<RexNode> operands, RelDataType retType, RelDataTypeFactory factory, RexBuilder rexBuilder) {
        List<RelDataType> argTypes = Util.transform(operands, RexNode::getType);
        List<RexNode> newOperands = new ArrayList<>(operands.size());
        boolean isCaseFunction = CoerceOperandShuttle.isCaseFunction(fn);
        boolean castedOperand = false;
        Preconditions.checkState(argTypes.isEmpty() || fn.getNumArgs() > 0);
        for (int i = 0; i < argTypes.size(); ++i) {
            RexNode original = operands.get(i);
            if (isCaseFunction && FunctionResolver.shouldSkipOperandForCase(argTypes.size(), i)) {
                newOperands.add(original);
                continue;
            }
            int indexToUse = Math.min(i, fn.getNumArgs() - 1);
            Type toImpalaType = fn.getArgs()[indexToUse];
            // NOTE(review): this inspects argTypes.get(indexToUse), not argTypes.get(i); the two
            // differ once i >= fn.getNumArgs(). Confirm this is intended.
            RelDataType toType = CoerceOperandShuttle.useReturnTypeForCastingArg(fn, argTypes.get(indexToUse))
                ? retType
                : CoerceOperandShuttle.getCastedToType(argTypes.get(i), toImpalaType, factory);
            RexNode operand = CoerceOperandShuttle.castOperandOrFail(original, toType, factory, rexBuilder);
            newOperands.add(operand);
            if (!original.equals(operand)) {
                castedOperand = true;
            }
        }
        return castedOperand ? newOperands : operands;
    }

    /**
     * True when an argument should be cast to the function's return type. This applies to
     * every argument of CASE, and to decimal arguments of var-arg functions returning a
     * decimal.
     */
    private static boolean useReturnTypeForCastingArg(Function fn, RelDataType argType) {
        if (CoerceOperandShuttle.isCaseFunction(fn)) {
            return true;
        }
        return SqlTypeUtil.isDecimal(argType) && fn.getReturnType().isDecimal() && fn.hasVarArgs();
    }

    private static boolean isCaseFunction(Function fn) {
        return fn.functionName().equals("case");
    }

    /**
     * Picks the Calcite type to cast {@code fromType} to for an Impala argument type.
     * <ul>
     *   <li>CHAR to CHAR keeps {@code fromType}.</li>
     *   <li>A non-decimal target, or a NULL source, uses the converted target type.</li>
     *   <li>Otherwise the result is the minimum-resolution decimal of the source type.</li>
     * </ul>
     */
    private static RelDataType getCastedToType(RelDataType fromType, Type toImpalaType, RelDataTypeFactory factory) {
        if (toImpalaType.equals(Type.CHAR) && fromType.getSqlTypeName() == SqlTypeName.CHAR) {
            return fromType;
        }
        if (!toImpalaType.isDecimal() || SqlTypeUtil.isNull(fromType)) {
            return ImpalaTypeConverter.getRelDataType(toImpalaType);
        }
        ScalarType impalaType = (ScalarType) ImpalaTypeConverter.createImpalaType(fromType);
        ScalarType decimalType = impalaType.getMinResolutionDecimal();
        return factory.createSqlType(SqlTypeName.DECIMAL, decimalType.decimalPrecision(), decimalType.decimalScale());
    }

    /**
     * Casts {@code node} to {@code toType} if needed.
     *
     * <p>Intervals, and nodes whose type name, precision and scale already match, are
     * returned unchanged. A NULL being cast to a decimal is cast to DECIMAL(1,0) instead.
     *
     * @return the (possibly cast) node, or {@code null} if no implicit cast is supported
     */
    private static RexNode castOperand(RexNode node, RelDataType toType, RelDataTypeFactory factory, RexBuilder rexBuilder) {
        RelDataType fromType = node.getType();
        if (SqlTypeUtil.isInterval(fromType)) {
            return node;
        }
        if (fromType.getSqlTypeName() == toType.getSqlTypeName()
                && fromType.getPrecision() == toType.getPrecision()
                && fromType.getScale() == toType.getScale()) {
            return node;
        }
        if (SqlTypeUtil.isNull(fromType)) {
            RelDataType nullTarget = toType;
            if (SqlTypeUtil.isDecimal(toType)) {
                Type impalaType = ImpalaTypeConverter.createImpalaType((Type) Type.DECIMAL, 1, 0);
                nullTarget = ImpalaTypeConverter.createRelDataType(impalaType);
            }
            return rexBuilder.makeCast(nullTarget, node);
        }
        if (!ImplicitTypeChecker.supportsImplicitCasting(fromType, toType)) {
            return null;
        }
        return rexBuilder.makeCast(toType, node);
    }

    /**
     * Like {@link #castOperand} but fails with a descriptive message instead of returning null.
     *
     * @throws NullPointerException if no implicit cast from the node's type to {@code toType}
     */
    private static RexNode castOperandOrFail(RexNode node, RelDataType toType, RelDataTypeFactory factory, RexBuilder rexBuilder) {
        RexNode casted = CoerceOperandShuttle.castOperand(node, toType, factory, rexBuilder);
        return Preconditions.checkNotNull(casted,
            "Cannot implicitly cast operand %s of type %s to %s", node, node.getType(), toType);
    }

    /**
     * Coerces DECODE(expr, search1, result1, ..., [default]) operands.
     *
     * <p>The expression and the odd-position operands are cast to the compatible search
     * type. The even-position operands and the trailing default (present when the operand
     * count is even) are cast to the compatible return type.
     *
     * @throws NullPointerException if an operand cannot be implicitly cast
     */
    private static RexNode castDecodedFunction(RexCall decodeCall, RelDataTypeFactory factory, RexBuilder rexBuilder) {
        List<RexNode> operands = decodeCall.getOperands();
        List<RelDataType> argTypes = Util.transform(operands, RexNode::getType);
        RelDataType searchOperand = ImpalaDecodeFunction.getCompatibleSearchOperand(argTypes, factory);
        RelDataType returnType = ImpalaDecodeFunction.getCompatibleReturnType(argTypes, factory);
        boolean hasElse = argTypes.size() % 2 == 0;
        int numNonElseParams = hasElse ? argTypes.size() - 1 : argTypes.size();
        List<RexNode> newOperands = new ArrayList<>(operands.size());
        newOperands.add(CoerceOperandShuttle.castOperandOrFail(operands.get(0), searchOperand, factory, rexBuilder));
        for (int i = 1; i < numNonElseParams; ++i) {
            RelDataType toType = i % 2 == 0 ? returnType : searchOperand;
            newOperands.add(CoerceOperandShuttle.castOperandOrFail(operands.get(i), toType, factory, rexBuilder));
        }
        if (hasElse) {
            int elseParam = operands.size() - 1;
            newOperands.add(CoerceOperandShuttle.castOperandOrFail(operands.get(elseParam), returnType, factory, rexBuilder));
        }
        return operands.equals(newOperands)
            ? decodeCall
            : (RexCall) rexBuilder.makeCall(returnType, decodeCall.getOperator(), newOperands);
    }
}

