/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tez.runtime.library.resources;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.EnumMap;
import java.util.HashSet;
import java.util.List;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configuration;
import org.apache.tez.common.Preconditions;
import org.apache.tez.runtime.common.resources.InitialMemoryAllocator;
import org.apache.tez.runtime.common.resources.InitialMemoryRequestContext;
import org.apache.tez.runtime.library.input.OrderedGroupedInputLegacy;
import org.apache.tez.runtime.library.input.OrderedGroupedKVInput;
import org.apache.tez.runtime.library.input.UnorderedKVInput;
import org.apache.tez.runtime.library.output.OrderedPartitionedKVOutput;
import org.apache.tez.runtime.library.output.UnorderedPartitionedKVOutput;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@InterfaceAudience.Public
@InterfaceStability.Unstable
public class WeightedScalingMemoryDistributor
implements InitialMemoryAllocator {
    private static final Logger LOG = LoggerFactory.getLogger(WeightedScalingMemoryDistributor.class);
    static final double MAX_ADDITIONAL_RESERVATION_FRACTION_PER_IO = 0.1;
    static final double RESERVATION_FRACTION_PER_IO = 0.015;
    static final String[] DEFAULT_TASK_MEMORY_WEIGHTED_RATIOS = WeightedScalingMemoryDistributor.generateWeightStrings(1, 1, 1, 12, 12, 1, 1);
    private Configuration conf;
    private EnumMap<RequestType, Integer> typeScaleMap = Maps.newEnumMap(RequestType.class);
    private int numRequests = 0;
    private int numRequestsScaled = 0;
    private long totalRequested = 0L;
    private List<Request> requests = Lists.newArrayList();

    public Iterable<Long> assignMemory(long availableForAllocation, int numTotalInputs, int numTotalOutputs, Iterable<InitialMemoryRequestContext> initialRequests) {
        this.populateTypeScaleMap();
        for (InitialMemoryRequestContext context : initialRequests) {
            this.initialProcessMemoryRequestContext(context);
        }
        if (this.numRequestsScaled == 0) {
            this.numRequestsScaled = this.numRequests;
            for (Request request : this.requests) {
                request.requestWeight = 1;
            }
        }
        double totalScaledRequest = 0.0;
        for (Request request : this.requests) {
            double requested = (double)request.requestSize * ((double)request.requestWeight / (double)this.numRequestsScaled);
            totalScaledRequest += requested;
        }
        double reserveFraction = this.computeReservedFraction(this.numRequests);
        Preconditions.checkState((reserveFraction >= 0.0 && reserveFraction <= 1.0 ? 1 : 0) != 0);
        availableForAllocation = (long)((double)availableForAllocation - reserveFraction * (double)availableForAllocation);
        long totalJvmMem = Runtime.getRuntime().maxMemory();
        double ratio = (double)this.totalRequested / (double)totalJvmMem;
        LOG.info("Scaling Requests. NumRequests: " + this.numRequests + ", numScaledRequests: " + this.numRequestsScaled + ", TotalRequested: " + this.totalRequested + ", TotalRequestedScaled: " + totalScaledRequest + ", TotalJVMHeap: " + totalJvmMem + ", TotalAvailable: " + availableForAllocation + ", TotalRequested/TotalJVMHeap:" + new DecimalFormat("0.00").format(ratio));
        int numInputRequestsScaled = 0;
        int numOutputRequestsScaled = 0;
        long totalInputAllocated = 0L;
        long totalOutputAllocated = 0L;
        ArrayList allocations = Lists.newArrayListWithCapacity((int)this.numRequests);
        for (Request request : this.requests) {
            long allocated = 0L;
            if (request.requestSize == 0L) {
                allocations.add(0L);
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Scaling requested " + request.componentClassname + " of type " + (Object)((Object)request.requestType) + " 0 to allocated: 0");
                }
            } else {
                double requestFactor = (double)request.requestWeight / (double)this.numRequestsScaled;
                double scaledRequest = requestFactor * (double)request.requestSize;
                allocated = Math.min((long)(scaledRequest / totalScaledRequest * (double)availableForAllocation), request.requestSize);
                allocations.add(allocated);
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Scaling requested " + request.componentClassname + " of type " + (Object)((Object)request.requestType) + " " + request.requestSize + "  to allocated: " + allocated);
                }
            }
            if (request.componentType == InitialMemoryRequestContext.ComponentType.INPUT) {
                numInputRequestsScaled += request.requestWeight;
                totalInputAllocated += allocated;
                continue;
            }
            if (request.componentType != InitialMemoryRequestContext.ComponentType.OUTPUT) continue;
            numOutputRequestsScaled += request.requestWeight;
            totalOutputAllocated += allocated;
        }
        if (!this.conf.getBoolean("tez.task.scale.memory.input-output-concurrent", true)) {
            this.adjustAllocationsForNonConcurrent(allocations, this.requests, numInputRequestsScaled, totalInputAllocated, numOutputRequestsScaled, totalOutputAllocated);
        }
        return allocations;
    }

    private void adjustAllocationsForNonConcurrent(List<Long> allocations, List<Request> requests, int numInputsScaled, long totalInputAllocated, int numOutputsScaled, long totalOutputAllocated) {
        boolean inputsEnabled = this.conf.getBoolean("tez.task.scale.memory.non-concurrent-inputs.enabled", false);
        LOG.info("Adjusting scaled allocations for I/O non-concurrent. numInputsScaled: {} InputAllocated: {} numOutputsScaled: {} outputAllocated: {} inputsEnabled: {}", new Object[]{numInputsScaled, totalInputAllocated, numOutputsScaled, totalOutputAllocated, inputsEnabled});
        for (int i = 0; i < requests.size(); ++i) {
            Request request = requests.get(i);
            long additional = 0L;
            if (request.componentType == InitialMemoryRequestContext.ComponentType.INPUT && inputsEnabled) {
                double share = (double)request.requestWeight / (double)numInputsScaled;
                additional = (long)((double)totalOutputAllocated * share);
            } else if (request.componentType == InitialMemoryRequestContext.ComponentType.OUTPUT) {
                double share = (double)request.requestWeight / (double)numOutputsScaled;
                additional = (long)((double)totalInputAllocated * share);
            }
            if (additional <= 0L) continue;
            long newTotal = Math.min(allocations.get(i) + additional, request.requestSize);
            allocations.set(i, newTotal);
            LOG.debug("Adding {} to {} total={}", new Object[]{additional, request.componentClassname, newTotal});
        }
    }

    private void initialProcessMemoryRequestContext(InitialMemoryRequestContext context) {
        ++this.numRequests;
        this.totalRequested += context.getRequestedSize();
        String className = context.getComponentClassName();
        RequestType requestType = this.getRequestTypeForClass(className);
        Integer typeScaleFactor = this.getScaleFactorForType(requestType);
        InitialMemoryRequestContext.ComponentType componentType = context.getComponentType();
        Request request = new Request(context.getComponentClassName(), componentType, context.getRequestedSize(), requestType, typeScaleFactor);
        this.requests.add(request);
        this.numRequestsScaled += typeScaleFactor.intValue();
    }

    private Integer getScaleFactorForType(RequestType requestType) {
        Integer typeScaleFactor = this.typeScaleMap.get((Object)requestType);
        if (typeScaleFactor == null) {
            LOG.warn("Bad scale factor for requestType: " + (Object)((Object)requestType) + ", Using factor 0");
            typeScaleFactor = 0;
        }
        return typeScaleFactor;
    }

    private RequestType getRequestTypeForClass(String className) {
        RequestType requestType;
        if (className.equals(OrderedPartitionedKVOutput.class.getName())) {
            requestType = RequestType.SORTED_OUTPUT;
        } else if (className.equals(OrderedGroupedKVInput.class.getName()) || className.equals(OrderedGroupedInputLegacy.class.getName())) {
            requestType = RequestType.SORTED_MERGED_INPUT;
        } else if (className.equals(UnorderedKVInput.class.getName())) {
            requestType = RequestType.UNSORTED_INPUT;
        } else if (className.equals(UnorderedPartitionedKVOutput.class.getName())) {
            requestType = RequestType.PARTITIONED_UNSORTED_OUTPUT;
        } else {
            requestType = RequestType.OTHER;
            LOG.debug("Falling back to RequestType.OTHER for class: {}", (Object)className);
        }
        return requestType;
    }

    private void populateTypeScaleMap() {
        String[] ratios = this.conf.getStrings("tez.task.scale.memory.ratios", DEFAULT_TASK_MEMORY_WEIGHTED_RATIOS);
        int numExpectedValues = RequestType.values().length;
        if (ratios == null) {
            LOG.info("No ratio specified. Falling back to Linear scaling");
            ratios = new String[numExpectedValues];
            int i = 0;
            for (RequestType requestType : RequestType.values()) {
                ratios[i] = requestType.name() + ":1";
                ++i;
            }
        } else if (ratios.length != RequestType.values().length) {
            throw new IllegalArgumentException("Number of entries in the configured ratios should be equal to the number of entries in RequestType: " + numExpectedValues);
        }
        StringBuilder sb = new StringBuilder();
        HashSet<RequestType> seenTypes = new HashSet<RequestType>();
        for (String ratio : ratios) {
            String[] parts = ratio.split(":");
            Preconditions.checkState((parts.length == 2 ? 1 : 0) != 0);
            RequestType requestType = RequestType.valueOf(parts[0]);
            Integer ratioVal = Integer.parseInt(parts[1]);
            if (!seenTypes.add(requestType)) {
                throw new IllegalArgumentException("Cannot configure the same RequestType: " + (Object)((Object)requestType) + " multiple times");
            }
            Preconditions.checkState((ratioVal >= 0 ? 1 : 0) != 0, (Object)"Ratio must be >= 0");
            this.typeScaleMap.put(requestType, ratioVal);
            sb.append("[").append((Object)requestType).append(":").append(ratioVal).append("]");
        }
        LOG.info("ScaleRatiosUsed=" + sb.toString());
    }

    private double computeReservedFraction(int numTotalRequests) {
        double additionalReserveFraction;
        double initialReserveFraction;
        double reserveFraction;
        double reserveFractionPerIo = this.conf.getDouble("tez.task.scale.memory.additional-reservation.fraction.per-io", 0.015);
        double maxAdditionalReserveFraction = this.conf.getDouble("tez.task.scale.memory.additional-reservation.fraction.max", 0.1);
        Preconditions.checkArgument((maxAdditionalReserveFraction >= 0.0 && maxAdditionalReserveFraction <= 1.0 ? 1 : 0) != 0);
        Preconditions.checkArgument((reserveFractionPerIo <= maxAdditionalReserveFraction && reserveFractionPerIo >= 0.0 ? 1 : 0) != 0);
        if (LOG.isDebugEnabled()) {
            LOG.debug("ReservationFractionPerIO=" + reserveFractionPerIo + ", MaxPerIOReserveFraction=" + maxAdditionalReserveFraction);
        }
        Preconditions.checkState(((reserveFraction = (initialReserveFraction = this.conf.getDouble("tez.task.scale.memory.reserve-fraction", 0.3)) + (additionalReserveFraction = Math.min(maxAdditionalReserveFraction, (double)numTotalRequests * reserveFractionPerIo))) <= 1.0 ? 1 : 0) != 0);
        LOG.info("InitialReservationFraction=" + initialReserveFraction + ", AdditionalReservationFractionForIOs=" + additionalReserveFraction + ", finalReserveFractionUsed=" + reserveFraction);
        return reserveFraction;
    }

    public static String[] generateWeightStrings(int unsortedPartitioned, int unsorted, int broadcastIn, int sortedOut, int scatterGatherShuffleIn, int proc, int other) {
        String[] weights = new String[RequestType.values().length];
        weights[0] = RequestType.PARTITIONED_UNSORTED_OUTPUT.name() + ":" + unsortedPartitioned;
        weights[1] = RequestType.UNSORTED_OUTPUT.name() + ":" + unsorted;
        weights[2] = RequestType.UNSORTED_INPUT.name() + ":" + broadcastIn;
        weights[3] = RequestType.SORTED_OUTPUT.name() + ":" + sortedOut;
        weights[4] = RequestType.SORTED_MERGED_INPUT.name() + ":" + scatterGatherShuffleIn;
        weights[5] = RequestType.PROCESSOR.name() + ":" + proc;
        weights[6] = RequestType.OTHER.name() + ":" + other;
        return weights;
    }

    public void setConf(Configuration conf) {
        this.conf = conf;
    }

    public Configuration getConf() {
        return this.conf;
    }

    @InterfaceAudience.Private
    @VisibleForTesting
    public static enum RequestType {
        PARTITIONED_UNSORTED_OUTPUT,
        UNSORTED_INPUT,
        UNSORTED_OUTPUT,
        SORTED_OUTPUT,
        SORTED_MERGED_INPUT,
        PROCESSOR,
        OTHER;

    }

    private static class Request {
        String componentClassname;
        InitialMemoryRequestContext.ComponentType componentType;
        long requestSize;
        private RequestType requestType;
        private int requestWeight;

        Request(String componentClassname, InitialMemoryRequestContext.ComponentType componentType, long requestSize, RequestType requestType, int requestWeight) {
            this.componentClassname = componentClassname;
            this.componentType = componentType;
            this.requestSize = requestSize;
            this.requestType = requestType;
            this.requestWeight = requestWeight;
        }
    }
}

