/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelOptUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class HiveJoinSwapConstraintsRule
extends RelOptRule {
    protected static final Logger LOG = LoggerFactory.getLogger(HiveJoinSwapConstraintsRule.class);
    public static final HiveJoinSwapConstraintsRule INSTANCE = new HiveJoinSwapConstraintsRule(HiveRelFactories.HIVE_BUILDER);

    protected HiveJoinSwapConstraintsRule(RelBuilderFactory relBuilder) {
        super(HiveJoinSwapConstraintsRule.operand(Join.class, (RelOptRuleOperand)HiveJoinSwapConstraintsRule.operand(Join.class, (RelOptRuleOperandChildren)HiveJoinSwapConstraintsRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[]{HiveJoinSwapConstraintsRule.operand(RelNode.class, (RelOptRuleOperandChildren)HiveJoinSwapConstraintsRule.any())}), relBuilder, "HiveJoinSwapConstraintsRule");
    }

    public void onMatch(RelOptRuleCall call) {
        int i2;
        Join topJoin = (Join)call.rel(0);
        Join bottomJoin = (Join)call.rel(1);
        RexBuilder rexBuilder = topJoin.getCluster().getRexBuilder();
        if (topJoin.getJoinType().generatesNullsOnLeft() || bottomJoin.getJoinType().generatesNullsOnLeft() || bottomJoin.isSemiJoin()) {
            return;
        }
        HiveRelOptUtil.RewritablePKFKJoinInfo topInfo = HiveRelOptUtil.isRewritablePKFKJoin(topJoin, topJoin.getLeft(), topJoin.getRight(), call.getMetadataQuery());
        HiveRelOptUtil.RewritablePKFKJoinInfo bottomInfo = HiveRelOptUtil.isRewritablePKFKJoin(bottomJoin, bottomJoin.getLeft(), bottomJoin.getRight(), call.getMetadataQuery());
        if (topInfo.rewritable || !bottomInfo.rewritable) {
            return;
        }
        int nFieldsX = bottomJoin.getLeft().getRowType().getFieldList().size();
        int nFieldsY = bottomJoin.getRight().getRowType().getFieldList().size();
        int nFieldsZ = topJoin.getRight().getRowType().getFieldList().size();
        int nTotalFields = nFieldsX + nFieldsY + nFieldsZ;
        ArrayList fields = new ArrayList();
        List joinFields = topJoin.getRowType().getFieldList();
        for (i2 = 0; i2 < nFieldsX + nFieldsY; ++i2) {
            fields.add(joinFields.get(i2));
        }
        joinFields = topJoin.getRight().getRowType().getFieldList();
        for (i2 = 0; i2 < nFieldsZ; ++i2) {
            fields.add(joinFields.get(i2));
        }
        Set<Integer> leftKeys = HiveCalciteUtil.getInputRefs(topJoin.getCondition());
        leftKeys.removeIf(i -> i >= topJoin.getLeft().getRowType().getFieldCount());
        int nKeysFromX = 0;
        for (int leftKey : leftKeys) {
            if (leftKey >= nFieldsX) continue;
            ++nKeysFromX;
        }
        if (nKeysFromX != leftKeys.size()) {
            return;
        }
        int[] adjustments = new int[nTotalFields];
        this.setJoinAdjustments(adjustments, nFieldsX, nFieldsY, nFieldsZ, nFieldsZ, -nFieldsY);
        RexNode newBottomCondition = (RexNode)topJoin.getCondition().accept((RexVisitor)new RelOptUtil.RexInputConverter(rexBuilder, fields, adjustments));
        Join newBottomJoin = topJoin.copy(topJoin.getTraitSet(), newBottomCondition, bottomJoin.getLeft(), topJoin.getRight(), topJoin.getJoinType(), topJoin.isSemiJoinDone());
        RexNode newTopCondition = newBottomJoin.isSemiJoin() ? bottomJoin.getCondition() : (RexNode)bottomJoin.getCondition().accept((RexVisitor)new RelOptUtil.RexInputConverter(rexBuilder, fields, adjustments));
        Join newTopJoin = bottomJoin.copy(bottomJoin.getTraitSet(), newTopCondition, (RelNode)newBottomJoin, bottomJoin.getRight(), bottomJoin.getJoinType(), bottomJoin.isSemiJoinDone());
        if (newBottomJoin.isSemiJoin()) {
            call.transformTo((RelNode)newTopJoin);
        } else {
            int i3;
            ArrayList<RexInputRef> exprs = new ArrayList<RexInputRef>();
            for (i3 = 0; i3 < nFieldsX; ++i3) {
                exprs.add(rexBuilder.makeInputRef((RelNode)newTopJoin, i3));
            }
            for (i3 = nFieldsX + nFieldsZ; i3 < topJoin.getRowType().getFieldCount(); ++i3) {
                exprs.add(rexBuilder.makeInputRef((RelNode)newTopJoin, i3));
            }
            for (i3 = nFieldsX; i3 < nFieldsX + nFieldsZ; ++i3) {
                exprs.add(rexBuilder.makeInputRef((RelNode)newTopJoin, i3));
            }
            call.transformTo(call.builder().push((RelNode)newTopJoin).project(exprs).build());
        }
    }

    private void setJoinAdjustments(int[] adjustments, int nFieldsX, int nFieldsY, int nFieldsZ, int adjustY, int adjustZ) {
        int i;
        for (i = 0; i < nFieldsX; ++i) {
            adjustments[i] = 0;
        }
        for (i = nFieldsX; i < nFieldsX + nFieldsY; ++i) {
            adjustments[i] = adjustY;
        }
        for (i = nFieldsX + nFieldsY; i < nFieldsX + nFieldsY + nFieldsZ; ++i) {
            adjustments[i] = adjustZ;
        }
    }
}

