/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Objects;
import java.util.TreeSet;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.metadata.RelColumnOrigin;
import org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.ImmutableLoptOptimizeJoinRule;
import org.apache.calcite.rel.rules.LoptJoinTree;
import org.apache.calcite.rel.rules.LoptMultiJoin;
import org.apache.calcite.rel.rules.LoptSemiJoinOptimizer;
import org.apache.calcite.rel.rules.MultiJoin;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.BitSets;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.mapping.IntPair;
import org.checkerframework.checker.nullness.qual.KeyFor;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.immutables.value.Value;

@Value.Enclosing
public class LoptOptimizeJoinRule
extends RelRule<Config>
implements TransformationRule {
    /**
     * Creates a LoptOptimizeJoinRule from its configuration.
     *
     * @param config rule configuration
     */
    protected LoptOptimizeJoinRule(Config config) {
        super(config);
    }

    /**
     * Creates a rule from {@code Config.DEFAULT} with the given builder factory.
     *
     * @param relBuilderFactory factory for the {@link RelBuilder} used by the rule
     * @deprecated kept for backward compatibility; equivalent to
     *     {@code Config.DEFAULT.withRelBuilderFactory(relBuilderFactory)}
     */
    @Deprecated
    public LoptOptimizeJoinRule(RelBuilderFactory relBuilderFactory) {
        this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory).as(Config.class));
    }

    /**
     * Creates a rule whose builder factory is a {@link RelBuilder#proto} of the given factories.
     *
     * @param joinFactory    factory for joins
     * @param projectFactory factory for projections
     * @param filterFactory  factory for filters
     * @deprecated kept for backward compatibility; delegates to the
     *     {@link RelBuilderFactory}-based constructor
     */
    @Deprecated
    public LoptOptimizeJoinRule(RelFactories.JoinFactory joinFactory, RelFactories.ProjectFactory projectFactory, RelFactories.FilterFactory filterFactory) {
        this(RelBuilder.proto(joinFactory, projectFactory, filterFactory));
    }

    /**
     * Optimizes the join order of the matched {@link MultiJoin}.
     *
     * <p>The steps are:
     * <ol>
     *   <li>mark removable outer joins;</li>
     *   <li>compute candidate semijoins and repeatedly apply the best one, for a bounded
     *       number of rounds;</li>
     *   <li>compute factor weights and mark removable self-joins;</li>
     *   <li>build one ordering per eligible starting factor and register each as an
     *       alternative plan.</li>
     * </ol>
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final MultiJoin multiJoinRel = (MultiJoin) call.rel(0);
        final LoptMultiJoin multiJoin = new LoptMultiJoin(multiJoinRel);
        final RelMetadataQuery mq = call.getMetadataQuery();

        findRemovableOuterJoins(mq, multiJoin);

        final RexBuilder rexBuilder = multiJoinRel.getCluster().getRexBuilder();
        final LoptSemiJoinOptimizer semiJoinOpt = new LoptSemiJoinOptimizer(call.getMetadataQuery(), multiJoin, rexBuilder);

        // Determine every possible semijoin, then keep applying the best one
        // until none is chosen or the round limit is exceeded.
        semiJoinOpt.makePossibleSemiJoins(multiJoin);
        int rounds = 0;
        while (semiJoinOpt.chooseBestSemiJoin(multiJoin)) {
            if (rounds++ > 10) {
                break;
            }
        }

        multiJoin.setFactorWeights();
        findRemovableSelfJoins(mq, multiJoin);
        findBestOrderings(mq, call.builder(), multiJoin, semiJoinOpt, call);
    }

    /**
     * Marks null-generating factors whose outer join can be removed entirely.
     *
     * <p>A null-generating factor is marked removable (via
     * {@link LoptMultiJoin#addRemovableOuterJoinFactor}) when all of the following hold:
     * <ul>
     *   <li>its projected-field set is non-null and empty;</li>
     *   <li>its outer-join condition has at least one equality between an input ref inside
     *       the factor and an input ref outside it (collected by {@code setJoinKey});</li>
     *   <li>each of its fields has a join reference count of at most 1, and a count of 1
     *       only for one of those equi-join keys;</li>
     *   <li>{@link RelMdUtil#areColumnsDefinitelyUniqueWhenNullsFiltered} reports the keys
     *       as unique.</li>
     * </ul>
     *
     * <p>After a removal, the reference count of each field on the other side of the
     * equi-join is decremented. Any null-generating factor that owns such a field is
     * re-examined in another pass.
     *
     * @param mq        metadata query used for the uniqueness check
     * @param multiJoin join factors; updated in place
     */
    private static void findRemovableOuterJoins(RelMetadataQuery mq, LoptMultiJoin multiJoin) {
        final List<Integer> removalCandidates = new ArrayList<>();
        for (int factIdx = 0; factIdx < multiJoin.getNumJoinFactors(); ++factIdx) {
            if (multiJoin.isNullGenerating(factIdx)) {
                removalCandidates.add(factIdx);
            }
        }

        while (!removalCandidates.isEmpty()) {
            final HashSet<Integer> retryCandidates = new HashSet<>();

            candidateLoop:
            for (int factIdx : removalCandidates) {
                // The factor must not contribute any projected fields.
                final ImmutableBitSet projFields = multiJoin.getProjFields(factIdx);
                if (projFields == null || projFields.cardinality() > 0) {
                    continue;
                }

                // Collect equi-join keys from the factor's outer-join condition.
                final List<RexNode> ojFilters = new ArrayList<>();
                RelOptUtil.decomposeConjunction(multiJoin.getOuterJoinCond(factIdx), ojFilters);
                final int firstFieldNum = multiJoin.getJoinStart(factIdx);
                final int lastFieldNum = firstFieldNum + multiJoin.getNumFieldsInJoinFactor(factIdx);
                final ImmutableBitSet.Builder joinKeyBuilder = ImmutableBitSet.builder();
                final ImmutableBitSet.Builder otherJoinKeyBuilder = ImmutableBitSet.builder();
                for (RexNode filter : ojFilters) {
                    if (!(filter instanceof RexCall)) {
                        continue;
                    }
                    final RexCall filterCall = (RexCall) filter;
                    if (filterCall.getOperator() != SqlStdOperatorTable.EQUALS) {
                        continue;
                    }
                    final RexNode lhs = filterCall.getOperands().get(0);
                    if (!(lhs instanceof RexInputRef)) {
                        continue;
                    }
                    final RexNode rhs = filterCall.getOperands().get(1);
                    if (!(rhs instanceof RexInputRef)) {
                        continue;
                    }
                    setJoinKey(joinKeyBuilder, otherJoinKeyBuilder,
                        ((RexInputRef) lhs).getIndex(), ((RexInputRef) rhs).getIndex(),
                        firstFieldNum, lastFieldNum, true);
                }
                if (joinKeyBuilder.cardinality() == 0) {
                    continue;
                }
                final ImmutableBitSet joinKeys = joinKeyBuilder.build();

                // Each field may be referenced at most once, and only as an equi-join key.
                final int[] joinFieldRefCounts = multiJoin.getJoinFieldRefCounts(factIdx);
                for (int i = 0; i < joinFieldRefCounts.length; ++i) {
                    final int refCount = joinFieldRefCounts[i];
                    if (refCount > 1 || (refCount == 1 && !joinKeys.get(i))) {
                        continue candidateLoop;
                    }
                }

                if (!RelMdUtil.areColumnsDefinitelyUniqueWhenNullsFiltered(mq, multiJoin.getJoinFactor(factIdx), joinKeys)) {
                    continue;
                }
                multiJoin.addRemovableOuterJoinFactor(factIdx);

                // The removed factor no longer references the fields it was joined to.
                // NOTE(review): relies on getJoinFieldRefCounts returning the live array held
                // by multiJoin, so this decrement is seen by later checks.
                for (int otherKey : otherJoinKeyBuilder.build()) {
                    final int otherFactor = multiJoin.findRef(otherKey);
                    if (multiJoin.isNullGenerating(otherFactor)) {
                        retryCandidates.add(otherFactor);
                    }
                    final int[] otherJoinFieldRefCounts = multiJoin.getJoinFieldRefCounts(otherFactor);
                    --otherJoinFieldRefCounts[otherKey - multiJoin.getJoinStart(otherFactor)];
                }
            }

            removalCandidates.clear();
            removalCandidates.addAll(retryCandidates);
        }
    }

    /**
     * Records one equi-join key pair when exactly one of the two references lies in the
     * factor's field range {@code [firstFieldNum, lastFieldNum)}.
     *
     * <p>The in-range reference is stored in {@code joinKeys}, made relative to
     * {@code firstFieldNum*}. The out-of-range reference is stored as-is in
     * {@code otherJoinKeys}. If {@code ref1} is out of range and {@code swap} is true, the
     * check is retried once with the references exchanged.
     */
    private static void setJoinKey(ImmutableBitSet.Builder joinKeys, ImmutableBitSet.Builder otherJoinKeys, int ref1, int ref2, int firstFieldNum, int lastFieldNum, boolean swap) {
        final boolean ref1InFactor = ref1 >= firstFieldNum && ref1 < lastFieldNum;
        if (!ref1InFactor) {
            if (swap) {
                setJoinKey(joinKeys, otherJoinKeys, ref2, ref1, firstFieldNum, lastFieldNum, false);
            }
            return;
        }
        final boolean ref2InFactor = ref2 >= firstFieldNum && ref2 < lastFieldNum;
        if (!ref2InFactor) {
            joinKeys.set(ref1 - firstFieldNum);
            otherJoinKeys.set(ref2);
        }
    }

    /**
     * Marks pairs of factors that scan the same table and can be collapsed into one.
     *
     * <p>Factors from {@code getSimpleFactors} are visited in ascending index order. Each
     * factor is paired with the first later factor whose table has the same qualified name,
     * unless its table has already been paired. A pair is registered via
     * {@link LoptMultiJoin#addRemovableSelfJoinPair} when at least one join filter references
     * exactly those two factors and {@code isSelfJoinFilterUnique} accepts those filters.
     *
     * @param mq        metadata query
     * @param multiJoin join factors; updated in place
     */
    private static void findRemovableSelfJoins(RelMetadataQuery mq, LoptMultiJoin multiJoin) {
        final Map<Integer, RelOptTable> simpleFactors = getSimpleFactors(mq, multiJoin);
        final List<RelOptTable> pairedTables = new ArrayList<>();
        final Map<Integer, Integer> selfJoinPairs = new HashMap<>();
        @KeyFor("simpleFactors") Integer[] factors = new TreeSet<>(simpleFactors.keySet()).toArray(new Integer[0]);

        for (int i = 0; i < factors.length; ++i) {
            if (pairedTables.contains(simpleFactors.get(factors[i]))) {
                continue;
            }
            for (int j = i + 1; j < factors.length; ++j) {
                final List<String> leftName = simpleFactors.get(factors[i]).getQualifiedName();
                final List<String> rightName = simpleFactors.get(factors[j]).getQualifiedName();
                if (leftName.equals(rightName)) {
                    selfJoinPairs.put(factors[i], factors[j]);
                    pairedTables.add(simpleFactors.get(factors[i]));
                    break;
                }
            }
        }

        for (Map.Entry<Integer, Integer> pair : selfJoinPairs.entrySet()) {
            final int leftFactor = pair.getKey();
            final int rightFactor = pair.getValue();
            final List<RexNode> pairFilters = new ArrayList<>();
            for (RexNode filter : multiJoin.getJoinFilters()) {
                final ImmutableBitSet refFactors = multiJoin.getFactorsRefByJoinFilter(filter);
                if (refFactors.cardinality() == 2 && refFactors.get(leftFactor) && refFactors.get(rightFactor)) {
                    pairFilters.add(filter);
                }
            }
            if (!pairFilters.isEmpty() && isSelfJoinFilterUnique(mq, multiJoin, leftFactor, rightFactor, pairFilters)) {
                multiJoin.addRemovableSelfJoinPair(leftFactor, rightFactor);
            }
        }
    }

    /**
     * Maps each factor index to its origin table, for factors eligible for self-join
     * removal.
     *
     * <p>A factor is included only if it is not null-generating, has no join-removal
     * factor, and {@code mq.getTableOrigin} returns a non-null table for it. The result is
     * empty when the MultiJoin is a full outer join.
     *
     * @param mq        metadata query used to find table origins
     * @param multiJoin join factors
     * @return mutable map from factor index to origin table
     */
    private static Map<Integer, RelOptTable> getSimpleFactors(RelMetadataQuery mq, LoptMultiJoin multiJoin) {
        final Map<Integer, RelOptTable> tablesByFactor = new HashMap<>();
        if (multiJoin.getMultiJoinRel().isFullOuterJoin()) {
            return tablesByFactor;
        }
        for (int factor = 0; factor < multiJoin.getNumJoinFactors(); ++factor) {
            if (multiJoin.isNullGenerating(factor)) {
                continue;
            }
            if (multiJoin.getJoinRemovalFactor(factor) != null) {
                continue;
            }
            final RelOptTable origin = mq.getTableOrigin(multiJoin.getJoinFactor(factor));
            if (origin != null) {
                tablesByFactor.put(factor, origin);
            }
        }
        return tablesByFactor;
    }

    /**
     * Checks whether the filters joining two factors of a candidate self-join are unique
     * keys.
     *
     * <p>The filters are ANDed together. Their input references are then shifted from the
     * MultiJoin field layout to a two-input layout: left factor fields start at 0, and right
     * factor fields follow them. The result is passed to {@code areSelfJoinKeysUnique}.
     *
     * @return the result of {@code areSelfJoinKeysUnique} on the remapped condition
     */
    private static boolean isSelfJoinFilterUnique(RelMetadataQuery mq, LoptMultiJoin multiJoin, int leftFactor, int rightFactor, List<RexNode> joinFilterList) {
        final RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder();
        final RelNode leftRel = multiJoin.getJoinFactor(leftFactor);
        final RelNode rightRel = multiJoin.getJoinFactor(rightFactor);
        final RexNode combined = RexUtil.composeConjunction(rexBuilder, joinFilterList);

        final int leftStart = multiJoin.getJoinStart(leftFactor);
        final int rightStart = multiJoin.getJoinStart(rightFactor);
        final int leftWidth = leftRel.getRowType().getFieldCount();
        final int rightWidth = rightRel.getRowType().getFieldCount();

        final int[] shifts = new int[multiJoin.getNumTotalFields()];
        for (int k = 0; k < leftWidth; ++k) {
            shifts[leftStart + k] = -leftStart;
        }
        for (int k = 0; k < rightWidth; ++k) {
            shifts[rightStart + k] = leftWidth - rightStart;
        }

        final RexNode remapped = combined.accept(new RelOptUtil.RexInputConverter(rexBuilder, multiJoin.getMultiJoinFields(), leftRel.getRowType().getFieldList(), rightRel.getRowType().getFieldList(), shifts));
        return areSelfJoinKeysUnique(mq, leftRel, rightRel, remapped);
    }

    /**
     * Builds one join ordering for each factor that is not null-generating, using that
     * factor as the first factor, and registers each completed plan with {@code call}.
     *
     * <p>Orderings for which {@code createOrdering} returns null are skipped. All plans are
     * built before any is passed to {@link RelOptRuleCall#transformTo}.
     *
     * @param mq          metadata query
     * @param relBuilder  builder used to construct join trees
     * @param multiJoin   join factors
     * @param semiJoinOpt semijoin optimizer holding the chosen semijoins
     * @param call        rule call used for the top projects and for registration
     */
    private static void findBestOrderings(RelMetadataQuery mq, RelBuilder relBuilder, LoptMultiJoin multiJoin, LoptSemiJoinOptimizer semiJoinOpt, RelOptRuleCall call) {
        final List<RelNode> plans = new ArrayList<>();
        final List<String> fieldNames = multiJoin.getMultiJoinRel().getRowType().getFieldNames();
        for (int i = 0; i < multiJoin.getNumJoinFactors(); ++i) {
            // Orderings are only started from factors that are not null-generating.
            if (multiJoin.isNullGenerating(i)) {
                continue;
            }
            final LoptJoinTree joinTree = createOrdering(mq, relBuilder, multiJoin, semiJoinOpt, i);
            if (joinTree == null) {
                continue;
            }
            plans.add(createTopProject(call.builder(), multiJoin, joinTree, fieldNames));
        }
        // Registration is deferred until every candidate plan has been built.
        for (RelNode plan : plans) {
            call.transformTo(plan);
        }
    }

    /**
     * Wraps the reordered join tree in a projection that restores the field order of the
     * original MultiJoin, naming the fields {@code fieldNames}. The MultiJoin's post-join
     * filter, if any, is then applied on top.
     *
     * <p>Some factors are the right side of a removable self-join. For those factors, each
     * field that {@link LoptMultiJoin#getRightColumnMapping} maps to a column of the partner
     * factor is projected from that partner column instead.
     *
     * @param relBuilder builder used for the project and filter
     * @param multiJoin  join factors
     * @param joinTree   reordered join tree
     * @param fieldNames output field names, in original MultiJoin order
     * @return the projected, and possibly filtered, plan
     */
    private static RelNode createTopProject(RelBuilder relBuilder, LoptMultiJoin multiJoin, LoptJoinTree joinTree, List<String> fieldNames) {
        final RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder();
        final List<Integer> newJoinOrder = joinTree.getTreeOrder();
        final int nJoinFactors = multiJoin.getNumJoinFactors();
        final List<RelDataTypeField> fields = multiJoin.getMultiJoinFields();

        // Offset at which each factor's fields start in the reordered join's output.
        final Map<Integer, Integer> factorToOffsetMap = new HashMap<>();
        int fieldStart = 0;
        for (int pos = 0; pos < nJoinFactors; ++pos) {
            final int factor = newJoinOrder.get(pos);
            factorToOffsetMap.put(factor, fieldStart);
            fieldStart += multiJoin.getNumFieldsInJoinFactor(factor);
        }

        final List<RexInputRef> newProjExprs = new ArrayList<>();
        for (int currFactor = 0; currFactor < nJoinFactors; ++currFactor) {
            // Non-null only when currFactor is the right side of a removable self-join.
            final Integer leftFactor = multiJoin.isRightFactorInRemovableSelfJoin(currFactor)
                ? multiJoin.getOtherSelfJoinFactor(currFactor)
                : null;
            for (int fieldPos = 0; fieldPos < multiJoin.getNumFieldsInJoinFactor(currFactor); ++fieldPos) {
                int newOffset = Objects.requireNonNull(factorToOffsetMap.get(currFactor), "factorToOffsetMap.get(currFactor)") + fieldPos;
                if (leftFactor != null) {
                    final Integer leftOffset = multiJoin.getRightColumnMapping(currFactor, fieldPos);
                    if (leftOffset != null) {
                        newOffset = Objects.requireNonNull(factorToOffsetMap.get(leftFactor), "factorToOffsetMap.get(leftFactor)") + leftOffset;
                    }
                }
                final RelDataType fieldType = fields.get(newProjExprs.size()).getType();
                newProjExprs.add(rexBuilder.makeInputRef(fieldType, newOffset));
            }
        }

        relBuilder.push(joinTree.getJoinTree());
        relBuilder.project(newProjExprs, fieldNames);
        final RexNode postJoinFilter = multiJoin.getMultiJoinRel().getPostJoinFilter();
        if (postJoinFilter != null) {
            relBuilder.filter(postJoinFilter);
        }
        return relBuilder.build();
    }

    /**
     * Estimates how selective joining {@code factor} to {@code joinTree} is, as the distinct
     * row count of the factor's join-key columns.
     *
     * <p>Join keys are the factor's fields that are referenced by filters in {@code filters},
     * or by its outer-join condition, and only by filters whose factors all lie in the tree
     * plus {@code factor}.
     *
     * @return {@code mq.getDistinctRowCount} over the chosen semijoin of {@code factor}, or
     *     null if no join keys were found
     */
    private static @Nullable Double computeJoinCardinality(RelMetadataQuery mq, LoptMultiJoin multiJoin, LoptSemiJoinOptimizer semiJoinOpt, LoptJoinTree joinTree, List<RexNode> filters, int factor) {
        final ImmutableBitSet childFactors = ImmutableBitSet.builder()
            .addAll(joinTree.getTreeOrder())
            .set(factor)
            .build();
        final int factorStart = multiJoin.getJoinStart(factor);
        final int nFields = multiJoin.getNumFieldsInJoinFactor(factor);
        final ImmutableBitSet.Builder joinKeys = ImmutableBitSet.builder();
        setFactorJoinKeys(multiJoin, filters, childFactors, factorStart, nFields, joinKeys);
        setFactorJoinKeys(multiJoin, RelOptUtil.conjunctions(multiJoin.getOuterJoinCond(factor)), childFactors, factorStart, nFields, joinKeys);
        if (joinKeys.isEmpty()) {
            return null;
        }
        return mq.getDistinctRowCount(semiJoinOpt.getChosenSemiJoin(factor), joinKeys.build(), null);
    }

    /**
     * Adds to {@code joinKeys} the factor-relative positions of the factor's fields that
     * each applicable filter references.
     *
     * <p>A filter applies only when every factor it references is in {@code joinFactors}.
     * Only fields in {@code [factorStart, factorStart + nFields)} are recorded, shifted so
     * that {@code factorStart} becomes 0.
     */
    private static void setFactorJoinKeys(LoptMultiJoin multiJoin, List<RexNode> filters, ImmutableBitSet joinFactors, int factorStart, int nFields, ImmutableBitSet.Builder joinKeys) {
        final int factorEnd = factorStart + nFields;
        for (RexNode filter : filters) {
            if (!joinFactors.contains(multiJoin.getFactorsRefByJoinFilter(filter))) {
                continue;
            }
            final ImmutableBitSet referenced = multiJoin.getFieldsRefByJoinFilter(filter);
            for (int bit = referenced.nextSetBit(factorStart); bit >= 0 && bit < factorEnd; bit = referenced.nextSetBit(bit + 1)) {
                joinKeys.set(bit - factorStart);
            }
        }
    }

    /**
     * Greedily builds a join tree containing every factor, starting with
     * {@code firstFactor}.
     *
     * <p>Each subsequent factor is picked as follows. If the previous factor has a
     * self-join partner that has not been added yet, the partner is picked. Otherwise the
     * factor returned by {@code getBestNextFactor} is picked.
     *
     * @return the completed tree, or null if {@code addFactorToTree} fails at any step
     */
    private static @Nullable LoptJoinTree createOrdering(RelMetadataQuery mq, RelBuilder relBuilder, LoptMultiJoin multiJoin, LoptSemiJoinOptimizer semiJoinOpt, int firstFactor) {
        final int factorCount = multiJoin.getNumJoinFactors();
        final BitSet remaining = BitSets.range(0, factorCount);
        final BitSet added = new BitSet(factorCount);
        final List<RexNode> pendingFilters = new ArrayList<>(multiJoin.getJoinFilters());
        LoptJoinTree tree = null;
        int lastAdded = -1;

        while (!remaining.isEmpty()) {
            final Integer partner = added.isEmpty() ? null : multiJoin.getOtherSelfJoinFactor(lastAdded);
            final boolean selfJoin = partner != null && !added.get(partner);
            final int candidate;
            if (added.isEmpty()) {
                candidate = firstFactor;
            } else if (selfJoin) {
                candidate = partner;
            } else {
                candidate = getBestNextFactor(mq, multiJoin, remaining, added, semiJoinOpt, tree, pendingFilters);
            }

            // Factors the candidate depends on, including its outer-join partners if it is
            // null-generating, restricted to those already in the tree.
            final BitSet needed = multiJoin.getFactorsRefByFactor(candidate).toBitSet();
            if (multiJoin.isNullGenerating(candidate)) {
                needed.or(multiJoin.getOuterJoinFactors(candidate).toBitSet());
            }
            needed.and(added);

            tree = addFactorToTree(mq, relBuilder, multiJoin, semiJoinOpt, tree, candidate, needed, pendingFilters, selfJoin);
            if (tree == null) {
                return null;
            }
            remaining.clear(candidate);
            added.set(candidate);
            lastAdded = candidate;
        }
        assert pendingFilters.size() == 0;
        return tree;
    }

    /**
     * Picks the next factor to add to the join tree.
     *
     * <p>A candidate is skipped in two cases: it has a join-removal factor that has not been
     * added yet, or it is null-generating and not all of its outer-join factors have been
     * added. Among the rest, the winner has the largest weight against any already-added
     * factor. Ties are broken by the larger {@code computeJoinCardinality}; a factor with a
     * known cardinality beats a tied factor with a null one only if no best cardinality has
     * been recorded yet.
     *
     * @return the chosen factor index, or -1 if no candidate qualifies
     */
    private static int getBestNextFactor(RelMetadataQuery mq, LoptMultiJoin multiJoin, BitSet factorsToAdd, BitSet factorsAdded, LoptSemiJoinOptimizer semiJoinOpt, @Nullable LoptJoinTree joinTree, List<RexNode> filtersToAdd) {
        int nextFactor = -1;
        int bestWeight = 0;
        Double bestCardinality = null;
        final int[][] factorWeights = multiJoin.getFactorWeights();
        for (int factor : BitSets.toIter(factorsToAdd)) {
            final Integer removalFactor = multiJoin.getJoinRemovalFactor(factor);
            if (removalFactor != null && !factorsAdded.get(removalFactor)) {
                continue;
            }
            if (multiJoin.isNullGenerating(factor) && !BitSets.contains(factorsAdded, multiJoin.getOuterJoinFactors(factor))) {
                continue;
            }

            // Strongest weight between this factor and any factor already in the tree.
            int dimWeight = 0;
            for (int prevFactor : BitSets.toIter(factorsAdded)) {
                final int[] factorWeight = Objects.requireNonNull(factorWeights, "factorWeights")[prevFactor];
                dimWeight = Math.max(dimWeight, factorWeight[factor]);
            }

            // Cardinality is only needed when this factor could win or tie.
            Double cardinality = null;
            if (dimWeight > 0 && dimWeight >= bestWeight) {
                cardinality = computeJoinCardinality(mq, multiJoin, semiJoinOpt, Objects.requireNonNull(joinTree, "joinTree"), filtersToAdd, factor);
            }

            final boolean better = dimWeight > bestWeight
                || (dimWeight == bestWeight
                    && (bestCardinality == null
                        || (cardinality != null && cardinality > bestCardinality)));
            if (better) {
                nextFactor = factor;
                bestWeight = dimWeight;
                bestCardinality = cardinality;
            }
        }
        return nextFactor;
    }

    /**
     * Returns whether {@code rel} is a {@link Join}. With assertions enabled, also checks
     * that such a join is not a FULL join.
     */
    private static boolean isJoinTree(RelNode rel) {
        if (!(rel instanceof Join)) {
            return false;
        }
        assert ((Join) rel).getJoinType() != JoinRelType.FULL;
        return true;
    }

    /**
     * Adds {@code factorToAdd} to {@code joinTree} and returns the resulting tree.
     *
     * <p>Cases are handled in this order:
     * <ul>
     *   <li>A removable outer-join factor goes through {@code createReplacementJoin}.</li>
     *   <li>A factor with a join-removal factor goes through
     *       {@code createReplacementSemiJoin}.</li>
     *   <li>If the tree is empty, the factor's chosen semijoin becomes the tree.</li>
     *   <li>Otherwise two placements are tried: joining at the top ({@code addToTop}) and
     *       pushing into a child ({@code pushDownFactor}). The one with lower cumulative cost
     *       wins. When the costs are equal within epsilon, the push-down tree wins only if its
     *       {@code rowWidthCost} is strictly smaller.</li>
     * </ul>
     *
     * @return the new tree, or null if neither placement succeeded
     */
    private static @Nullable LoptJoinTree addFactorToTree(RelMetadataQuery mq, RelBuilder relBuilder, LoptMultiJoin multiJoin, LoptSemiJoinOptimizer semiJoinOpt, @Nullable LoptJoinTree joinTree, int factorToAdd, BitSet factorsNeeded, List<RexNode> filtersToAdd, boolean selfJoin) {
        if (multiJoin.isRemovableOuterJoinFactor(factorToAdd)) {
            return createReplacementJoin(relBuilder, multiJoin, semiJoinOpt, Objects.requireNonNull(joinTree, "joinTree"), -1, factorToAdd, ImmutableIntList.of(), null, filtersToAdd);
        }
        if (multiJoin.getJoinRemovalFactor(factorToAdd) != null) {
            return createReplacementSemiJoin(relBuilder, multiJoin, semiJoinOpt, joinTree, factorToAdd, filtersToAdd);
        }
        if (joinTree == null) {
            return new LoptJoinTree(semiJoinOpt.getChosenSemiJoin(factorToAdd), factorToAdd);
        }

        // addToTop may remove entries from filtersToAdd; keep a snapshot for push-down.
        final List<RexNode> snapshot = new ArrayList<>(filtersToAdd);
        final LoptJoinTree atTop = addToTop(mq, relBuilder, multiJoin, semiJoinOpt, joinTree, factorToAdd, filtersToAdd, selfJoin);
        final List<RexNode> pushDownFilters = atTop == null ? filtersToAdd : snapshot;
        final LoptJoinTree pushedDown = pushDownFactor(mq, relBuilder, multiJoin, semiJoinOpt, joinTree, factorToAdd, factorsNeeded, pushDownFilters, selfJoin);

        final RelOptCost pushedDownCost = pushedDown == null ? null : mq.getCumulativeCost(pushedDown.getJoinTree());
        final RelOptCost atTopCost = atTop == null ? null : mq.getCumulativeCost(atTop.getJoinTree());

        if (pushedDown == null) {
            return atTop;
        }
        if (atTop == null) {
            return pushedDown;
        }
        Objects.requireNonNull(pushedDownCost, "costPushDown");
        Objects.requireNonNull(atTopCost, "costTop");
        if (pushedDownCost.isEqWithEpsilon(atTopCost)) {
            return rowWidthCost(pushedDown.getJoinTree()) < rowWidthCost(atTop.getJoinTree()) ? pushedDown : atTop;
        }
        return pushedDownCost.isLt(atTopCost) ? pushedDown : atTop;
    }

    /**
     * Sums the output field counts of {@code tree} and, recursively, of every join input
     * below it.
     */
    private static int rowWidthCost(RelNode tree) {
        final int ownWidth = tree.getRowType().getFieldCount();
        if (!isJoinTree(tree)) {
            return ownWidth;
        }
        final Join join = (Join) tree;
        return ownWidth + rowWidthCost(join.getLeft()) + rowWidthCost(join.getRight());
    }

    /**
     * Tries to add {@code factorToAdd} inside one child of the top join of
     * {@code joinTree}, rather than above it.
     *
     * <p>The child is chosen as follows:
     * <ul>
     *   <li>For a self-join, the child containing the factor's self-join partner.</li>
     *   <li>Otherwise the left child, if no factors are needed or the left child has all of
     *       them, and the join does not generate nulls on the left.</li>
     *   <li>Otherwise the right child, if it has all needed factors and the join does not
     *       generate nulls on the right.</li>
     * </ul>
     * The join condition is then remapped to the new field layout. For joins other than
     * LEFT and RIGHT, pending filters now covered by both children are ANDed in.
     *
     * @return the rebuilt tree, or null if {@code joinTree} is not a join, is a removable
     *     self-join, or no child qualifies
     */
    private static @Nullable LoptJoinTree pushDownFactor(RelMetadataQuery mq, RelBuilder relBuilder, LoptMultiJoin multiJoin, LoptSemiJoinOptimizer semiJoinOpt, LoptJoinTree joinTree, int factorToAdd, BitSet factorsNeeded, List<RexNode> filtersToAdd, boolean selfJoin) {
        if (!isJoinTree(joinTree.getJoinTree())) {
            return null;
        }
        LoptJoinTree left = joinTree.getLeft();
        LoptJoinTree right = joinTree.getRight();
        final Join joinRel = (Join) joinTree.getJoinTree();
        final JoinRelType joinType = joinRel.getJoinType();
        if (joinTree.isRemovableSelfJoin()) {
            return null;
        }

        // 0 = push into the left child, 1 = push into the right child, -1 = not possible.
        int childNo = -1;
        if (selfJoin) {
            final BitSet selfJoinFactor = new BitSet(multiJoin.getNumJoinFactors());
            final Integer factor = Objects.requireNonNull(multiJoin.getOtherSelfJoinFactor(factorToAdd), () -> "multiJoin.getOtherSelfJoinFactor(" + factorToAdd + ") is null");
            selfJoinFactor.set(factor);
            if (multiJoin.hasAllFactors(left, selfJoinFactor)) {
                childNo = 0;
            } else {
                assert multiJoin.hasAllFactors(right, selfJoinFactor);
                childNo = 1;
            }
        } else if (factorsNeeded.cardinality() == 0 && !joinType.generatesNullsOnLeft()) {
            childNo = 0;
        } else if (multiJoin.hasAllFactors(left, factorsNeeded) && !joinType.generatesNullsOnLeft()) {
            childNo = 0;
        } else if (multiJoin.hasAllFactors(right, factorsNeeded) && !joinType.generatesNullsOnRight()) {
            childNo = 1;
        }
        if (childNo == -1) {
            return null;
        }

        // Field layout before the factor is added, needed to remap the join condition.
        final List<Integer> origJoinOrder = joinTree.getTreeOrder();
        LoptJoinTree subTree = childNo == 0 ? left : right;
        subTree = addFactorToTree(mq, relBuilder, multiJoin, semiJoinOpt, subTree, factorToAdd, factorsNeeded, filtersToAdd, selfJoin);
        if (childNo == 0) {
            left = subTree;
        } else {
            right = subTree;
        }

        RexNode newCondition = ((Join) joinTree.getJoinTree()).getCondition();
        newCondition = adjustFilter(multiJoin, Objects.requireNonNull(left, "left"), Objects.requireNonNull(right, "right"), newCondition, factorToAdd, origJoinOrder, joinTree.getJoinTree().getRowType().getFieldList());
        if (joinType != JoinRelType.LEFT && joinType != JoinRelType.RIGHT) {
            final RexNode condition = addFilters(multiJoin, left, -1, right, filtersToAdd, true);
            final RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder();
            newCondition = RelOptUtil.andJoinFilters(rexBuilder, newCondition, condition);
        }
        return createJoinSubtree(mq, relBuilder, multiJoin, left, right, newCondition, joinType, filtersToAdd, false, false);
    }

    /**
     * Joins {@code factorToAdd}, as its chosen semijoin, on top of {@code joinTree}, with the
     * factor as the right input.
     *
     * <p>The join type is chosen as follows:
     * <ul>
     *   <li>FULL when the MultiJoin is a full outer join (two factors are asserted).</li>
     *   <li>LEFT when the factor is null-generating; the condition is then the factor's
     *       outer-join condition.</li>
     *   <li>INNER otherwise; the condition is built from pending filters via
     *       {@code addFilters}.</li>
     * </ul>
     *
     * @return the new tree, or null for a self-join whose existing tree is already a join
     */
    private static @Nullable LoptJoinTree addToTop(RelMetadataQuery mq, RelBuilder relBuilder, LoptMultiJoin multiJoin, LoptSemiJoinOptimizer semiJoinOpt, LoptJoinTree joinTree, int factorToAdd, List<RexNode> filtersToAdd, boolean selfJoin) {
        if (selfJoin && isJoinTree(joinTree.getJoinTree())) {
            return null;
        }

        final JoinRelType joinType;
        if (multiJoin.getMultiJoinRel().isFullOuterJoin()) {
            assert multiJoin.getNumJoinFactors() == 2;
            joinType = JoinRelType.FULL;
        } else if (multiJoin.isNullGenerating(factorToAdd)) {
            joinType = JoinRelType.LEFT;
        } else {
            joinType = JoinRelType.INNER;
        }

        final LoptJoinTree rightTree = new LoptJoinTree(semiJoinOpt.getChosenSemiJoin(factorToAdd), factorToAdd);
        final RexNode condition;
        if (joinType == JoinRelType.LEFT || joinType == JoinRelType.RIGHT) {
            condition = Objects.requireNonNull(multiJoin.getOuterJoinCond(factorToAdd), "multiJoin.getOuterJoinCond(factorToAdd)");
        } else {
            condition = addFilters(multiJoin, joinTree, -1, rightTree, filtersToAdd, false);
        }
        return createJoinSubtree(mq, relBuilder, multiJoin, joinTree, rightTree, condition, joinType, filtersToAdd, true, selfJoin);
    }

    /**
     * Removes from {@code filtersToAdd} every filter whose referenced factors all lie in the
     * new join, and returns their conjunction.
     *
     * <p>The new join consists of the factors of {@code rightTree} plus either factor
     * {@code leftIdx}, when it is non-negative, or the factors of {@code leftTree}. When
     * {@code adjust} is true and {@code needsAdjustment} reports it, the input references in
     * the conjunction are remapped to the join's field layout.
     *
     * @return the combined condition, or the literal TRUE if no filter matched
     */
    private static RexNode addFilters(LoptMultiJoin multiJoin, LoptJoinTree leftTree, int leftIdx, LoptJoinTree rightTree, List<RexNode> filtersToAdd, boolean adjust) {
        final RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder();

        // Factors that make up the new join.
        final ImmutableBitSet.Builder childFactorBuilder = ImmutableBitSet.builder();
        childFactorBuilder.addAll(rightTree.getTreeOrder());
        if (leftIdx >= 0) {
            childFactorBuilder.set(leftIdx);
        } else {
            childFactorBuilder.addAll(leftTree.getTreeOrder());
        }
        final ImmutableBitSet childFactor = childFactorBuilder.build();

        // Consume and AND together every pending filter fully covered by the new join.
        @Nullable RexNode condition = null;
        final ListIterator<RexNode> filterIter = filtersToAdd.listIterator();
        while (filterIter.hasNext()) {
            final RexNode joinFilter = filterIter.next();
            if (!childFactor.contains(multiJoin.getFactorsRefByJoinFilter(joinFilter))) {
                continue;
            }
            condition = condition == null
                ? joinFilter
                : rexBuilder.makeCall(SqlStdOperatorTable.AND, condition, joinFilter);
            filterIter.remove();
        }

        if (adjust && condition != null) {
            final int[] adjustments = new int[multiJoin.getNumTotalFields()];
            if (needsAdjustment(multiJoin, adjustments, leftTree, rightTree, false)) {
                condition = condition.accept(new RelOptUtil.RexInputConverter(rexBuilder, multiJoin.getMultiJoinFields(), leftTree.getJoinTree().getRowType().getFieldList(), rightTree.getJoinTree().getRowType().getFieldList(), adjustments));
            }
        }
        if (condition == null) {
            condition = rexBuilder.makeLiteral(true);
        }
        return condition;
    }

    /**
     * Remaps the input references of {@code condition} after {@code factorAdded} has been
     * inserted into a join tree.
     *
     * <p>References in {@code condition} follow the layout given by {@code origJoinOrder}
     * and {@code origFields}. Each factor other than {@code factorAdded} is shifted to the
     * position it now occupies in the ordering of {@code left} followed by {@code right}.
     * Per-field adjustments come from {@code remapJoinReferences}.
     *
     * @return the remapped condition, or {@code condition} itself if nothing moved
     */
    private static RexNode adjustFilter(LoptMultiJoin multiJoin, LoptJoinTree left, LoptJoinTree right, RexNode condition, int factorAdded, List<Integer> origJoinOrder, List<RelDataTypeField> origFields) {
        final List<Integer> newOrder = new ArrayList<>();
        left.getTreeOrder(newOrder);
        right.getTreeOrder(newOrder);

        final int oldFieldCount = left.getJoinTree().getRowType().getFieldCount()
            + right.getJoinTree().getRowType().getFieldCount()
            - multiJoin.getNumFieldsInJoinFactor(factorAdded);
        final int[] shifts = new int[oldFieldCount];

        boolean anyShift = false;
        int newStart = 0;
        int position = 0;
        for (int factor : newOrder) {
            if (factor != factorAdded) {
                // Start of this factor in the original layout: total width of the factors
                // that preceded it there.
                int oldStart = 0;
                for (int earlier : origJoinOrder) {
                    if (earlier == factor) {
                        break;
                    }
                    oldStart += multiJoin.getNumFieldsInJoinFactor(earlier);
                }
                anyShift |= remapJoinReferences(multiJoin, factor, newOrder, position, shifts, oldStart, newStart, false);
            }
            newStart += multiJoin.getNumFieldsInJoinFactor(factor);
            ++position;
        }

        if (!anyShift) {
            return condition;
        }
        final RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder();
        return condition.accept(new RelOptUtil.RexInputConverter(rexBuilder, origFields, left.getJoinTree().getRowType().getFieldList(), right.getJoinTree().getRowType().getFieldList(), shifts));
    }

    /**
     * Writes per-field shifts into {@code adjustments}, indexed from {@code offset}, for a
     * factor that moves from field offset {@code offset} to {@code newOffset}.
     *
     * <p>The default shift is {@code newOffset - offset}; it is written only when non-zero.
     * A different rule applies when all of the following hold:
     * <ul>
     *   <li>{@code alwaysUseDefault} is false;</li>
     *   <li>{@code factor} is the right side of a removable self-join;</li>
     *   <li>its partner sits immediately before it in {@code newJoinOrder}.</li>
     * </ul>
     * In that case every field gets a shift. Fields that
     * {@link LoptMultiJoin#getRightColumnMapping} maps to a partner column are pointed at
     * that column; other fields use the default shift.
     *
     * @return whether any non-zero shift applies to the factor's fields
     */
    private static boolean remapJoinReferences(LoptMultiJoin multiJoin, int factor, List<Integer> newJoinOrder, int newPos, int[] adjustments, int offset, int newOffset, boolean alwaysUseDefault) {
        final int defaultAdjustment = -offset + newOffset;
        final boolean mapThroughSelfJoin = !alwaysUseDefault
            && multiJoin.isRightFactorInRemovableSelfJoin(factor)
            && newPos != 0
            && newJoinOrder.get(newPos - 1).equals(multiJoin.getOtherSelfJoinFactor(factor));

        if (!mapThroughSelfJoin) {
            if (defaultAdjustment == 0) {
                return false;
            }
            for (int i = 0; i < multiJoin.getNumFieldsInJoinFactor(newJoinOrder.get(newPos)); ++i) {
                adjustments[i + offset] = defaultAdjustment;
            }
            return true;
        }

        final int nLeftFields = multiJoin.getNumFieldsInJoinFactor(newJoinOrder.get(newPos - 1));
        boolean needAdjust = false;
        for (int i = 0; i < multiJoin.getNumFieldsInJoinFactor(factor); ++i) {
            final Integer leftOffset = multiJoin.getRightColumnMapping(factor, i);
            final int shift = leftOffset == null
                ? defaultAdjustment
                : -(offset + i) + (newOffset - nLeftFields) + leftOffset;
            adjustments[i + offset] = shift;
            needAdjust |= shift != 0;
        }
        return needAdjust;
    }

    /**
     * Builds a replacement for the removable dimension factor {@code dimIdx}.
     * The dimension's key columns are taken from the fact factor that
     * {@code multiJoin.getJoinRemovalFactor(dimIdx)} names, which must already
     * be part of {@code factTree}.
     *
     * <p>The field offset of the fact factor inside {@code factTree} is the sum
     * of the field counts of the factors before it in the tree order. Each
     * dimension key (left key of the join-removal semi-join) is then mapped to
     * the matching fact key (right key) shifted by that offset, and the result
     * is handed to {@code createReplacementJoin}.
     *
     * @return the replacement tree, or {@code null} when {@code factTree} is null
     */
    private static @Nullable LoptJoinTree createReplacementSemiJoin(RelBuilder relBuilder, LoptMultiJoin multiJoin, LoptSemiJoinOptimizer semiJoinOpt, @Nullable LoptJoinTree factTree, int dimIdx, List<RexNode> filtersToAdd) {
        if (factTree == null) {
            return null;
        }
        final int factIdx = Objects.requireNonNull(multiJoin.getJoinRemovalFactor(dimIdx), () -> "multiJoin.getJoinRemovalFactor(dimIdx) for " + dimIdx + ", " + multiJoin);
        final List<Integer> joinOrder = factTree.getTreeOrder();
        assert joinOrder.contains(factIdx);

        // Field offset of the fact factor within the current tree.
        int factOffset = 0;
        for (Integer factor : joinOrder) {
            if (factor == factIdx) {
                break;
            }
            factOffset += multiJoin.getNumFieldsInJoinFactor(factor);
        }

        final int nDimFields = multiJoin.getJoinFactor(dimIdx).getRowType().getFieldList().size();
        final Integer[] replacementKeys = new Integer[nDimFields];
        final LogicalJoin semiJoin = multiJoin.getJoinRemovalSemiJoin(dimIdx);
        final ImmutableIntList dimKeys = semiJoin.analyzeCondition().leftKeys;
        final ImmutableIntList factKeys = semiJoin.analyzeCondition().rightKeys;
        for (int k = 0; k < dimKeys.size(); k++) {
            replacementKeys[dimKeys.get(k)] = factKeys.get(k) + factOffset;
        }
        return LoptOptimizeJoinRule.createReplacementJoin(relBuilder, multiJoin, semiJoinOpt, factTree, factIdx, dimIdx, dimKeys, replacementKeys, filtersToAdd);
    }

    /**
     * Builds a projection that stands in for joining {@code factorToAdd} onto
     * {@code currJoinTree}.
     *
     * <p>The projection keeps every field of the current tree and appends one
     * expression per field of {@code factorToAdd}:
     * <ul>
     * <li>a field whose index is in {@code newKeys} is read from the current
     * tree at {@code replacementKeys[i]}. It is wrapped in a cast when the two
     * {@link RelDataType} references are not identical (compared with
     * {@code ==});</li>
     * <li>every other field becomes a NULL literal. The literal's type is made
     * nullable when {@code replacementKeys} is null.</li>
     * </ul>
     *
     * <p>{@code addFilters} is then called against a tree built from
     * {@code semiJoinOpt.getChosenSemiJoin(factorToAdd)}; its returned
     * condition is not used here. When {@code leftIdx >= 0}, additional filters
     * are layered over the projection.
     *
     * @return tree whose rel is the (possibly filtered) projection, and whose
     *     factor tree combines {@code currJoinTree}'s factors with {@code factorToAdd}
     */
    private static LoptJoinTree createReplacementJoin(RelBuilder relBuilder, LoptMultiJoin multiJoin, LoptSemiJoinOptimizer semiJoinOpt, LoptJoinTree currJoinTree, int leftIdx, int factorToAdd, ImmutableIntList newKeys, Integer @Nullable [] replacementKeys, List<RexNode> filtersToAdd) {
        final RelNode currJoinRel = currJoinTree.getJoinTree();
        final List<RelDataTypeField> currFields = currJoinRel.getRowType().getFieldList();
        final int currCount = currFields.size();
        final List<RelDataTypeField> addedFields = multiJoin.getJoinFactor(factorToAdd).getRowType().getFieldList();
        final int addedCount = addedFields.size();
        final List<Pair<RexNode, String>> projects = new ArrayList<>();
        final RexBuilder rexBuilder = currJoinRel.getCluster().getRexBuilder();
        final RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();

        // Existing fields pass through unchanged.
        int pos = 0;
        while (pos < currCount) {
            final RelDataTypeField field = currFields.get(pos);
            projects.add(Pair.of(rexBuilder.makeInputRef(field.getType(), pos), field.getName()));
            pos++;
        }

        // Fields of the added factor: mapped key columns, otherwise NULL.
        for (int col = 0; col < addedCount; col++) {
            final RelDataTypeField addedField = addedFields.get(col);
            final RexNode projExpr;
            if (newKeys.contains(Integer.valueOf(col))) {
                Objects.requireNonNull(replacementKeys, "replacementKeys");
                final RelDataTypeField mappedField = currFields.get(replacementKeys[col]);
                final RexNode mappedInput = rexBuilder.makeInputRef(mappedField.getType(), replacementKeys[col].intValue());
                projExpr = mappedField.getType() == addedField.getType()
                    ? mappedInput
                    : rexBuilder.makeCast(addedField.getType(), mappedInput);
            } else {
                RelDataType nullType = addedField.getType();
                if (replacementKeys == null) {
                    nullType = typeFactory.createTypeWithNullability(nullType, true);
                }
                projExpr = rexBuilder.makeNullLiteral(nullType);
            }
            projects.add(Pair.of(projExpr, addedField.getName()));
        }

        relBuilder.push(currJoinRel);
        relBuilder.project(Pair.left(projects), Pair.right(projects));

        final LoptJoinTree addedTree = new LoptJoinTree(semiJoinOpt.getChosenSemiJoin(factorToAdd), factorToAdd);
        LoptOptimizeJoinRule.addFilters(multiJoin, currJoinTree, leftIdx, addedTree, filtersToAdd, false);
        if (leftIdx >= 0) {
            LoptOptimizeJoinRule.addAdditionalFilters(relBuilder, multiJoin, currJoinTree, addedTree, filtersToAdd);
        }
        return new LoptJoinTree(relBuilder.build(), currJoinTree.getFactorTree(), addedTree.getFactorTree());
    }

    /**
     * Joins {@code left} and {@code right} into a new {@link LoptJoinTree}.
     *
     * <p>If {@code swapInputs} asks for it, the inputs are exchanged. In that
     * case:
     * <ul>
     * <li>the condition is rewritten with {@code swapFilter}, unless
     * {@code fullAdjust} is set;</li>
     * <li>a LEFT join becomes RIGHT, INNER and FULL are kept, and any other
     * type becomes LEFT.</li>
     * </ul>
     * With {@code fullAdjust}, the condition is remapped from multi-join field
     * positions onto the two inputs whenever {@code needsAdjustment} reports it
     * necessary. For LEFT and RIGHT joins, remaining applicable filters are
     * added on top through {@code addAdditionalFilters}.
     *
     * @return tree for the new join, built from the (possibly swapped) inputs' factor trees
     */
    private static LoptJoinTree createJoinSubtree(RelMetadataQuery mq, RelBuilder relBuilder, LoptMultiJoin multiJoin, LoptJoinTree left, LoptJoinTree right, RexNode condition, JoinRelType joinType, List<RexNode> filtersToAdd, boolean fullAdjust, boolean selfJoin) {
        final RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder();
        LoptJoinTree leftTree = left;
        LoptJoinTree rightTree = right;
        RexNode joinCond = condition;
        JoinRelType effectiveType = joinType;

        if (LoptOptimizeJoinRule.swapInputs(mq, multiJoin, leftTree, rightTree, selfJoin)) {
            leftTree = right;
            rightTree = left;
            if (!fullAdjust) {
                // swapFilter expects the pre-swap (left, right) pair.
                joinCond = LoptOptimizeJoinRule.swapFilter(rexBuilder, multiJoin, left, right, joinCond);
            }
            if (effectiveType == JoinRelType.LEFT) {
                effectiveType = JoinRelType.RIGHT;
            } else if (effectiveType != JoinRelType.INNER && effectiveType != JoinRelType.FULL) {
                effectiveType = JoinRelType.LEFT;
            }
        }

        if (fullAdjust) {
            final int[] adjustments = new int[multiJoin.getNumTotalFields()];
            if (LoptOptimizeJoinRule.needsAdjustment(multiJoin, adjustments, leftTree, rightTree, selfJoin)) {
                joinCond = joinCond.accept(new RelOptUtil.RexInputConverter(rexBuilder, multiJoin.getMultiJoinFields(), leftTree.getJoinTree().getRowType().getFieldList(), rightTree.getJoinTree().getRowType().getFieldList(), adjustments));
            }
        }

        relBuilder.push(leftTree.getJoinTree()).push(rightTree.getJoinTree()).join(effectiveType, joinCond);

        final boolean oneSidedOuter = effectiveType == JoinRelType.LEFT || effectiveType == JoinRelType.RIGHT;
        if (oneSidedOuter) {
            assert !selfJoin;
            LoptOptimizeJoinRule.addAdditionalFilters(relBuilder, multiJoin, leftTree, rightTree, filtersToAdd);
        }
        return new LoptJoinTree(relBuilder.build(), leftTree.getFactorTree(), rightTree.getFactorTree(), selfJoin);
    }

    /**
     * Places the filters that {@code addFilters} selects for the combined
     * {@code left}/{@code right} tree on top of the rel currently held by
     * {@code relBuilder}.
     *
     * <p>The selected condition is expressed in multi-join field positions.
     * When {@code needsAdjustment} reports that the tree's layout differs from
     * the multi-join layout, the condition is first remapped onto the row type
     * of the rel on top of {@code relBuilder}. The filter is applied in either
     * case, which matches how {@code createJoinSubtree} uses its join condition.
     *
     * @param relBuilder   builder whose top rel is the join tree to filter
     * @param multiJoin    join factors being optimized
     * @param left         left side of the join tree
     * @param right        right side of the join tree
     * @param filtersToAdd filters not yet placed in the plan
     */
    private static void addAdditionalFilters(RelBuilder relBuilder, LoptMultiJoin multiJoin, LoptJoinTree left, LoptJoinTree right, List<RexNode> filtersToAdd) {
        RexNode filterCond = LoptOptimizeJoinRule.addFilters(multiJoin, left, -1, right, filtersToAdd, false);
        if (filterCond.isAlwaysTrue()) {
            return;
        }
        // NOTE(review): this relies on addFilters taking the filters it returns
        // out of filtersToAdd (TODO confirm). If it does, skipping the filter
        // when no remapping is needed would drop the condition from the plan.
        int[] adjustments = new int[multiJoin.getNumTotalFields()];
        if (LoptOptimizeJoinRule.needsAdjustment(multiJoin, adjustments, left, right, false)) {
            RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder();
            filterCond = filterCond.accept(new RelOptUtil.RexInputConverter(rexBuilder, multiJoin.getMultiJoinFields(), relBuilder.peek().getRowType().getFieldList(), adjustments));
        }
        relBuilder.filter(filterCond);
    }

    /**
     * Decides whether the two inputs of a join should be exchanged.
     *
     * <p>For a self-join, the inputs are swapped unless the left leaf factor is
     * the left factor of a removable self-join. Otherwise they are swapped only
     * when both row counts are known and either:
     * <ul>
     * <li>the left has fewer rows, or</li>
     * <li>the row counts are equal within 1e-5 and the left has a smaller
     * {@code rowWidthCost}.</li>
     * </ul>
     * The net effect is that the smaller input ends up on the right.
     *
     * @return true if the caller should swap {@code left} and {@code right}
     */
    private static boolean swapInputs(RelMetadataQuery mq, LoptMultiJoin multiJoin, LoptJoinTree left, LoptJoinTree right, boolean selfJoin) {
        if (selfJoin) {
            return !multiJoin.isLeftFactorInRemovableSelfJoin(((LoptJoinTree.Leaf) left.getFactorTree()).getId());
        }
        final Double leftRows = mq.getRowCount(left.getJoinTree());
        final Double rightRows = mq.getRowCount(right.getJoinTree());
        if (leftRows == null || rightRows == null) {
            return false;
        }
        if (leftRows < rightRows) {
            return true;
        }
        // Tie on row count (allowing for round-off): break it on row width.
        return Math.abs(leftRows - rightRows) < 1.0E-5
            && LoptOptimizeJoinRule.rowWidthCost(left.getJoinTree()) < LoptOptimizeJoinRule.rowWidthCost(right.getJoinTree());
    }

    /**
     * Rewrites {@code condition}, written against the fields of
     * {@code origLeft} followed by {@code origRight}, so that it refers to the
     * fields of {@code origRight} followed by {@code origLeft}.
     *
     * <p>Each left field shifts right by the width of {@code origRight`, and
     * each right field shifts left by the width of {@code origLeft}.
     *
     * @return the rewritten condition
     */
    private static RexNode swapFilter(RexBuilder rexBuilder, LoptMultiJoin multiJoin, LoptJoinTree origLeft, LoptJoinTree origRight, RexNode condition) {
        final int leftWidth = origLeft.getJoinTree().getRowType().getFieldCount();
        final int rightWidth = origRight.getJoinTree().getRowType().getFieldCount();
        final int[] adjustments = new int[leftWidth + rightWidth];
        for (int idx = 0; idx < adjustments.length; idx++) {
            adjustments[idx] = idx < leftWidth ? rightWidth : -leftWidth;
        }
        return condition.accept(new RelOptUtil.RexInputConverter(rexBuilder, multiJoin.getJoinFields(origLeft, origRight), multiJoin.getJoinFields(origRight, origLeft), adjustments));
    }

    /**
     * Fills {@code adjustments} with the shift that moves each factor's fields
     * from their multi-join position ({@code multiJoin.getJoinStart}) to their
     * position in the concatenated factor order of {@code joinTree} followed by
     * {@code otherTree}.
     *
     * @param multiJoin   join factors being optimized
     * @param adjustments output array indexed by multi-join field position
     * @param joinTree    first part of the join order
     * @param otherTree   second part of the join order; skipped when null
     * @param selfJoin    passed to {@code remapJoinReferences} as
     *                    {@code alwaysUseDefault}
     * @return true if {@code remapJoinReferences} reported an adjustment for
     *     any factor
     */
    private static boolean needsAdjustment(LoptMultiJoin multiJoin, int[] adjustments, LoptJoinTree joinTree, LoptJoinTree otherTree, boolean selfJoin) {
        final List<Integer> order = new ArrayList<>();
        joinTree.getTreeOrder(order);
        if (otherTree != null) {
            otherTree.getTreeOrder(order);
        }
        boolean anyShift = false;
        int fieldsSoFar = 0;
        for (int pos = 0; pos < order.size(); pos++) {
            final int factorId = order.get(pos);
            // '|=' (not '||') so the remap always runs and fills adjustments.
            anyShift |= LoptOptimizeJoinRule.remapJoinReferences(multiJoin, factorId, order, pos, adjustments, multiJoin.getJoinStart(factorId), fieldsSoFar, selfJoin);
            fieldsSoFar += multiJoin.getNumFieldsInJoinFactor(factorId);
        }
        return anyShift;
    }

    /**
     * Returns whether {@code joinRel} is a removable self-join. All of the
     * following must hold:
     * <ul>
     * <li>the join type is not an outer join;</li>
     * <li>both inputs have a table origin
     * ({@code RelMetadataQuery#getTableOrigin}), and the two origins have the
     * same qualified name;</li>
     * <li>the join keys pass {@code areSelfJoinKeysUnique}.</li>
     * </ul>
     *
     * @param joinRel join to inspect
     * @return true if the join can be treated as a removable self-join
     */
    public static boolean isRemovableSelfJoin(Join joinRel) {
        final RelNode leftInput = joinRel.getLeft();
        final RelNode rightInput = joinRel.getRight();
        if (joinRel.getJoinType().isOuterJoin()) {
            return false;
        }
        final RelMetadataQuery mq = joinRel.getCluster().getMetadataQuery();
        final RelOptTable leftOrigin = mq.getTableOrigin(leftInput);
        if (leftOrigin == null) {
            return false;
        }
        final RelOptTable rightOrigin = mq.getTableOrigin(rightInput);
        return rightOrigin != null
            && leftOrigin.getQualifiedName().equals(rightOrigin.getQualifiedName())
            && LoptOptimizeJoinRule.areSelfJoinKeysUnique(mq, leftInput, rightInput, joinRel.getCondition());
    }

    /**
     * Checks that every equi-join key pair of a self-join refers to the same
     * underlying column on both sides, and that the left keys are unique once
     * nulls are filtered out.
     *
     * <p>Each key must have a single column origin that is <em>not</em>
     * derived, i.e. a direct copy of a base column. Two derived expressions
     * (for example {@code a + 1} and {@code a * 2}) can report the same origin
     * ordinal without being equal, so they cannot prove that the keys match.
     * The previous check was inverted: it rejected direct column copies and
     * accepted derived ones.
     *
     * @param mq          metadata query
     * @param leftRel     left input of the join
     * @param rightRel    right input of the join
     * @param joinFilters join condition
     * @return true if the keys line up column-for-column and are unique on the left
     */
    private static boolean areSelfJoinKeysUnique(RelMetadataQuery mq, RelNode leftRel, RelNode rightRel, RexNode joinFilters) {
        final JoinInfo joinInfo = JoinInfo.of(leftRel, rightRel, joinFilters);
        for (IntPair pair : joinInfo.pairs()) {
            final RelColumnOrigin leftOrigin = mq.getColumnOrigin(leftRel, pair.source);
            if (leftOrigin == null || leftOrigin.isDerived()) {
                return false;
            }
            final RelColumnOrigin rightOrigin = mq.getColumnOrigin(rightRel, pair.target);
            if (rightOrigin == null || rightOrigin.isDerived()) {
                return false;
            }
            if (leftOrigin.getOriginColumnOrdinal() != rightOrigin.getOriginColumnOrdinal()) {
                return false;
            }
        }
        return RelMdUtil.areColumnsDefinitelyUniqueWhenNullsFiltered(mq, leftRel, joinInfo.leftSet());
    }

    /** Configuration for {@link LoptOptimizeJoinRule}. */
    @Value.Immutable
    public interface Config extends RelRule.Config {
        /** Default configuration: a single operand matching any {@link MultiJoin}. */
        Config DEFAULT = ImmutableLoptOptimizeJoinRule.Config.of()
            .withOperandSupplier(operandBuilder -> operandBuilder.operand(MultiJoin.class).anyInputs());

        /** Creates a {@link LoptOptimizeJoinRule} from this configuration. */
        default LoptOptimizeJoinRule toRule() {
            return new LoptOptimizeJoinRule(this);
        }
    }
}

