/*
 * Decompiled with CFR 0.152.
 */
package org.apache.impala.planner;

import com.google.common.base.Preconditions;
import com.google.common.math.LongMath;
import java.math.RoundingMode;
import java.util.List;
import java.util.function.Supplier;
import org.apache.impala.planner.BaseProcessingCost;
import org.apache.impala.planner.BroadcastProcessingCost;
import org.apache.impala.planner.ScaledProcessingCost;
import org.apache.impala.planner.SumProcessingCost;
import org.apache.impala.service.BackendConfig;

/**
 * Base class for processing-cost estimates used by the planner.
 *
 * <p>A cost exposes a total cost ({@link #getTotalCost()}) plus bookkeeping used to size
 * parallelism: the number of rows it consumes and produces, an optional label shown by
 * {@link #debugString()}, and an optional supplier that overrides the expected instance
 * count. Instances are created through the static factories ({@link #basicCost},
 * {@link #sumCost}, {@link #scaleCost}, {@link #broadcastCost}, {@link #zero()},
 * {@link #invalid()}).
 */
public abstract class ProcessingCost implements Cloneable {
    /**
     * Supplies an adjusted instance count. {@code null}, or a supplier returning a
     * non-positive value, means "not adjusted": {@link #getNumInstanceMax()} is used instead.
     */
    protected Supplier<Integer> numInstanceSupplier_ = null;
    private long numRowToProduce_ = 0L;
    private long numRowToConsume_ = 0L;
    private String label_ = null;
    // Whether the row counts were explicitly set. getDetails() only prints the derived
    // per-row ratios for counts that were actually provided.
    private boolean isSetNumRowToProduce_ = false;
    private boolean isSetNumRowToConsume_ = false;

    /** Returns a placeholder cost for an invalid estimate (a base cost with cardinality -1). */
    public static ProcessingCost invalid() {
        return new BaseProcessingCost(-1L, 1.0f, 0.0f);
    }

    /** Returns a base cost with cardinality 0. */
    public static ProcessingCost zero() {
        return new BaseProcessingCost(0L, 1.0f, 0.0f);
    }

    /** Returns whichever of {@code a} and {@code b} has the larger total cost ({@code a} on ties). */
    public static ProcessingCost maxCost(ProcessingCost a, ProcessingCost b) {
        return a.getTotalCost() >= b.getTotalCost() ? a : b;
    }

    /** Returns a cost representing the sum of {@code a} and {@code b}. */
    public static ProcessingCost sumCost(ProcessingCost a, ProcessingCost b) {
        return new SumProcessingCost(a, b);
    }

    /** Returns {@code cost} scaled by {@code factor}. */
    public static ProcessingCost scaleCost(ProcessingCost cost, long factor) {
        return new ScaledProcessingCost(cost, factor);
    }

    /** Returns {@code cost} wrapped as a broadcast cost driven by {@code numInstanceSupplier}. */
    public static ProcessingCost broadcastCost(
            ProcessingCost cost, Supplier<Integer> numInstanceSupplier) {
        return new BroadcastProcessingCost(cost, numInstanceSupplier);
    }

    /**
     * Adjusts the expected instance count of {@code consumer} relative to {@code producer}.
     *
     * <p>If the producer has a positive per-row production cost, and the consumer's instance
     * count can move one {@code nodeStepCount} step down (not below {@code minParallelism})
     * or up (not above {@code maxParallelism}) towards the consumer/producer cost ratio, then
     * the consumer count is set to that ratio times the producer's instance count. The result
     * is rounded up to a multiple of {@code nodeStepCount} and clamped to
     * {@code [minParallelism, maxParallelism]}. Otherwise the consumer count is only capped
     * at {@code maxParallelism}.
     */
    protected static void tryAdjustConsumerParallelism(int nodeStepCount, int minParallelism,
            int maxParallelism, ProcessingCost producer, ProcessingCost consumer) {
        Preconditions.checkState(consumer.getNumInstancesExpected() > 0);
        Preconditions.checkState(producer.getNumInstancesExpected() > 0);
        boolean canAdjust = producer.getCostPerRowProduced() > 0.0f
                && (consumer.canReducedBy(nodeStepCount, minParallelism, producer)
                        || consumer.canIncreaseBy(nodeStepCount, maxParallelism, producer));
        if (canAdjust) {
            float consProdRatio = consumer.consumerProducerRatio(producer);
            // Number of nodeStepCount-sized steps, rounded up. Work in long with a saturating
            // multiply: a huge ratio must clamp to maxParallelism rather than wrap negative
            // (which the previous int arithmetic did) and collapse to minParallelism.
            long steps = (long) Math.ceil(consProdRatio
                    * (float) producer.getNumInstancesExpected() / (float) nodeStepCount);
            long adjustedCount = LongMath.saturatedMultiply(steps, nodeStepCount);
            final int finalCount =
                    (int) Math.max(minParallelism, Math.min(maxParallelism, adjustedCount));
            consumer.setNumInstanceExpected(() -> finalCount);
        } else if (maxParallelism < consumer.getNumInstancesExpected()) {
            consumer.setNumInstanceExpected(() -> maxParallelism);
        }
    }

    /** Builds a base cost, clamping a negative cardinality to 0. */
    private static ProcessingCost computeValidBaseCost(
            long cardinality, float exprsCost, float materializationCost) {
        return new BaseProcessingCost(Math.max(0L, cardinality), exprsCost, materializationCost);
    }

    /**
     * Returns a labeled base cost.
     *
     * @param label label shown by {@link #debugString()}
     * @param cardinality row count; negative values are clamped to 0
     * @param exprsCost expression cost passed to the base cost
     * @param materializationCost materialization cost passed to the base cost
     */
    public static ProcessingCost basicCost(
            String label, long cardinality, float exprsCost, float materializationCost) {
        ProcessingCost processingCost =
                computeValidBaseCost(cardinality, exprsCost, materializationCost);
        processingCost.setLabel(label);
        return processingCost;
    }

    /** Same as {@link #basicCost(String, long, float, float)} with zero materialization cost. */
    public static ProcessingCost basicCost(String label, long cardinality, float exprsCost) {
        return basicCost(label, cardinality, exprsCost, 0.0f);
    }

    /**
     * Returns a labeled base cost built directly from a total cost.
     *
     * @throws IllegalArgumentException if the base cost rejects {@code totalCost}; the
     *     original exception is kept as the cause and the message names {@code label}
     */
    public static ProcessingCost basicCost(String label, double totalCost) {
        try {
            BaseProcessingCost processingCost = new BaseProcessingCost(totalCost);
            processingCost.setLabel(label);
            return processingCost;
        } catch (IllegalArgumentException ex) {
            throw new IllegalArgumentException(
                    String.format("Invalid totalCost supplied for %s", label), ex);
        }
    }

    /**
     * Sums {@code costs} into a single cost.
     *
     * <p>The consumed and produced row counts of the result are the (saturating) sums of the
     * inputs' counts. Its expected instance count is the maximum over the inputs (at least 1).
     *
     * @throws NullPointerException if {@code costs} is null
     * @throws IllegalArgumentException if {@code costs} is empty
     */
    protected static ProcessingCost fullMergeCosts(List<ProcessingCost> costs) {
        Preconditions.checkNotNull(costs);
        Preconditions.checkArgument(!costs.isEmpty());
        ProcessingCost resultingCost = ProcessingCost.zero();
        long inputCardinality = 0L;
        long outputCardinality = 0L;
        int maxProducerParallelism = 1;
        for (ProcessingCost cost : costs) {
            resultingCost = ProcessingCost.sumCost(resultingCost, cost);
            // Saturate instead of wrapping: row counts may be very large estimates, and a
            // wrapped negative sum would silently be clamped to 0 by the setters below.
            inputCardinality = LongMath.saturatedAdd(inputCardinality, cost.getNumRowToConsume());
            outputCardinality =
                    LongMath.saturatedAdd(outputCardinality, cost.getNumRowToProduce());
            maxProducerParallelism =
                    Math.max(maxProducerParallelism, cost.getNumInstancesExpected());
        }
        resultingCost.setNumRowToConsume(inputCardinality);
        resultingCost.setNumRowToProduce(outputCardinality);
        final int finalProducerParallelism = maxProducerParallelism;
        resultingCost.setNumInstanceExpected(() -> finalProducerParallelism);
        return resultingCost;
    }

    /** Returns the total cost of this estimate. */
    public abstract long getTotalCost();

    /** Returns whether this estimate is valid. */
    public abstract boolean isValid();

    @Override
    public abstract ProcessingCost clone();

    /**
     * Returns a one-line summary: total cost, max/adjusted instance counts, cost per instance,
     * consumed:produced rows, and the per-row ratios for the row counts that were set.
     */
    public String getDetails() {
        StringBuilder output = new StringBuilder();
        output.append("cost-total=").append(getTotalCost())
                .append(" max-instances=").append(getNumInstanceMax());
        if (hasAdjustedInstanceCount()) {
            output.append(" adj-instances=").append(getNumInstancesExpected());
        }
        output.append(" cost/inst=").append(getPerInstanceCost())
                .append(" #cons:#prod=").append(numRowToConsume_)
                .append(":").append(numRowToProduce_);
        if (isSetNumRowToConsume_ && isSetNumRowToProduce_) {
            output.append(" reduction=").append(getReduction());
        }
        if (isSetNumRowToConsume_) {
            output.append(" cost/cons=").append(getCostPerRowConsumed());
        }
        if (isSetNumRowToProduce_) {
            output.append(" cost/prod=").append(getCostPerRowProduced());
        }
        return output.toString();
    }

    /** Returns {@code "label="} (when a label is set) followed by {@link #toString()}. */
    public String debugString() {
        StringBuilder output = new StringBuilder();
        if (label_ != null) {
            output.append(label_);
            output.append("=");
        }
        output.append(this);
        return output.toString();
    }

    @Override
    public String toString() {
        return "{" + getDetails() + "}";
    }

    /**
     * Returns {@code detailPrefix} followed by {@link #getDetails()}. {@code fullExplain} is
     * not used by this implementation.
     */
    public String getExplainString(String detailPrefix, boolean fullExplain) {
        return detailPrefix + getDetails();
    }

    /**
     * Overrides the expected instance count with {@code countSupplier}.
     *
     * @throws IllegalArgumentException if the supplier's current value is not positive
     */
    public void setNumInstanceExpected(Supplier<Integer> countSupplier) {
        Preconditions.checkArgument(
                countSupplier.get() > 0, "Number of instance must be greater than 0!");
        this.numInstanceSupplier_ = countSupplier;
    }

    /** Returns the adjusted instance count if one is set and positive, else the maximum. */
    public int getNumInstancesExpected() {
        // Evaluate the supplier once so the value checked is the value returned.
        int adjusted = adjustedInstanceCount();
        return adjusted > 0 ? adjusted : getNumInstanceMax();
    }

    /** Returns the supplier's current count, or 0 when no supplier is set. */
    private int adjustedInstanceCount() {
        return numInstanceSupplier_ == null ? 0 : numInstanceSupplier_.get();
    }

    private boolean hasAdjustedInstanceCount() {
        return adjustedInstanceCount() > 0;
    }

    /** Returns the maximum instance count assuming a single node. */
    protected int getNumInstanceMax() {
        return getNumInstanceMax(1);
    }

    /**
     * Returns ceil(totalCost / minProcessingPerThread), rounded up to a multiple of
     * {@code numNodes} and bounded as in {@link #roundUpNumNodeMultiple(long, int)}.
     */
    protected int getNumInstanceMax(int numNodes) {
        long maxParallelism = LongMath.divide(getTotalCost(),
                BackendConfig.INSTANCE.getMinProcessingPerThread(), RoundingMode.CEILING);
        return roundUpNumNodeMultiple(maxParallelism, numNodes);
    }

    /**
     * Rounds {@code parallelism} up to a multiple of {@code numNodes}.
     *
     * <p>Returns 1 for non-positive {@code parallelism}. Results above
     * {@code Integer.MAX_VALUE} are capped at the largest multiple of {@code numNodes} that
     * fits in an int.
     *
     * @throws IllegalArgumentException if {@code numNodes} is not positive
     */
    protected static int roundUpNumNodeMultiple(long parallelism, int numNodes) {
        Preconditions.checkArgument(numNodes > 0, "numNodes must be positive: %s", numNodes);
        if (parallelism <= 0L) {
            return 1;
        }
        long cap = Integer.MAX_VALUE - Integer.MAX_VALUE % numNodes;
        // Check the cap before multiplying: for parallelism near Long.MAX_VALUE the rounded
        // product would overflow to a negative value and previously got clamped to 1.
        if (parallelism > cap) {
            return (int) cap;
        }
        // parallelism <= cap and cap is a multiple of numNodes, so the result is <= cap.
        return (int) (LongMath.divide(parallelism, numNodes, RoundingMode.CEILING) * numNodes);
    }

    /** Sets the produced row count (negative values are clamped to 0) and marks it as set. */
    public void setNumRowToProduce(long numRowToProduce) {
        this.numRowToProduce_ = Math.max(0L, numRowToProduce);
        this.isSetNumRowToProduce_ = true;
    }

    /** Sets the consumed row count (negative values are clamped to 0) and marks it as set. */
    protected void setNumRowToConsume(long numRowToConsume) {
        this.numRowToConsume_ = Math.max(0L, numRowToConsume);
        this.isSetNumRowToConsume_ = true;
    }

    /** Sets the label shown by {@link #debugString()}. */
    public void setLabel(String label) {
        this.label_ = label;
    }

    public long getNumRowToConsume() {
        return numRowToConsume_;
    }

    public long getNumRowToProduce() {
        return numRowToProduce_;
    }

    /** Returns ceil(totalCost / expected instances), computed in float precision. */
    private int getPerInstanceCost() {
        int instances = getNumInstancesExpected();
        Preconditions.checkState(instances > 0);
        return (int) Math.ceil((float) getTotalCost() / (float) instances);
    }

    /** Returns consumed rows divided by produced rows (denominator at least 1). */
    private float getReduction() {
        return (float) numRowToConsume_ / (float) Math.max(1L, numRowToProduce_);
    }

    /** Returns total cost divided by produced rows (denominator at least 1). */
    private float getCostPerRowProduced() {
        return (float) getTotalCost() / (float) Math.max(1L, numRowToProduce_);
    }

    /** Returns total cost divided by consumed rows (denominator at least 1). */
    private float getCostPerRowConsumed() {
        return (float) getTotalCost() / (float) Math.max(1L, numRowToConsume_);
    }

    /** Returns this cost's expected instance count divided by {@code other}'s. */
    private float instanceRatio(ProcessingCost other) {
        int thisCount = getNumInstancesExpected();
        int otherCount = other.getNumInstancesExpected();
        Preconditions.checkState(thisCount > 0);
        // The denominator is the one that must be positive for the ratio to be meaningful.
        Preconditions.checkState(otherCount > 0);
        return (float) thisCount / (float) otherCount;
    }

    /**
     * Returns this cost's per-row consumption cost divided by {@code other}'s per-row
     * production cost (denominator at least 1).
     */
    private float consumerProducerRatio(ProcessingCost other) {
        return getCostPerRowConsumed() / Math.max(1.0f, other.getCostPerRowProduced());
    }

    /**
     * Returns true if stepping down by {@code nodeStepCount} would go below
     * {@code minParallelism} or undershoot the consumer/producer ratio.
     */
    private boolean isAtLowestInstanceRatio(
            int nodeStepCount, int minParallelism, ProcessingCost other) {
        if (getNumInstancesExpected() - nodeStepCount < minParallelism) {
            return true;
        }
        float lowerRatio = (float) (getNumInstancesExpected() - nodeStepCount)
                / (float) other.getNumInstancesExpected();
        return lowerRatio < consumerProducerRatio(other);
    }

    /**
     * Returns true if stepping up by {@code nodeStepCount} would exceed {@code maxInstance}
     * or overshoot the consumer/producer ratio.
     */
    private boolean isAtHighestInstanceRatio(
            int nodeStepCount, int maxInstance, ProcessingCost other) {
        if (getNumInstancesExpected() + nodeStepCount > maxInstance) {
            return true;
        }
        float higherRatio = (float) (getNumInstancesExpected() + nodeStepCount)
                / (float) other.getNumInstancesExpected();
        return higherRatio > consumerProducerRatio(other);
    }

    /**
     * Returns true if the instance ratio exceeds the cost ratio and one step down is still
     * allowed.
     */
    private boolean canReducedBy(int nodeStepCount, int minParallelism, ProcessingCost other) {
        return !isAtLowestInstanceRatio(nodeStepCount, minParallelism, other)
                && consumerProducerRatio(other) < instanceRatio(other);
    }

    /**
     * Returns true if the cost ratio exceeds the instance ratio and one step up is still
     * allowed.
     */
    private boolean canIncreaseBy(int nodeStepCount, int maxInstance, ProcessingCost other) {
        return !isAtHighestInstanceRatio(nodeStepCount, maxInstance, other)
                && consumerProducerRatio(other) > instanceRatio(other);
    }
}

