/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.yarn.server.globalpolicygenerator.policygenerator;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.collections4.MapUtils;
import org.apache.hadoop.classification.VisibleForTesting;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.exceptions.YarnRuntimeException;
import org.apache.hadoop.yarn.server.federation.policies.manager.FederationPolicyManager;
import org.apache.hadoop.yarn.server.federation.policies.manager.WeightedLocalityPolicyManager;
import org.apache.hadoop.yarn.server.federation.store.records.SubClusterId;
import org.apache.hadoop.yarn.server.federation.store.records.SubClusterIdInfo;
import org.apache.hadoop.yarn.server.globalpolicygenerator.GPGUtils;
import org.apache.hadoop.yarn.server.globalpolicygenerator.policygenerator.GlobalPolicy;
import org.apache.hadoop.yarn.server.resourcemanager.webapp.dao.ClusterMetricsInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link GlobalPolicy} that derives per-sub-cluster routing weights from the
 * number of pending applications reported by each sub-cluster's
 * {@link ClusterMetricsInfo} ({@code /metrics}).
 *
 * <p>All sub-clusters start from the weights produced by
 * {@link GPGUtils#createUniformWeights}. Then up to {@code maxEdit} of the most
 * loaded sub-clusters are re-weighted:
 * <ul>
 *   <li>pending &lt;= {@code minPending}: weight left untouched;</li>
 *   <li>{@code minPending} &lt; pending &lt; {@code maxPending}: weight scaled by the
 *       configured {@link Scaling} into the range between {@code minWeight} and 1;</li>
 *   <li>pending &gt;= {@code maxPending}: weight set to {@code minWeight}.</li>
 * </ul>
 * If every resulting weight is non-positive, all weights are reset to 1.0f.
 */
public class LoadBasedGlobalPolicy extends GlobalPolicy {
    private static final Logger LOG = LoggerFactory.getLogger(LoadBasedGlobalPolicy.class);

    private static final String CONF_PREFIX = "yarn.federation.gpg.policy.generator.load-based.";
    private static final String MIN_PENDING_KEY = CONF_PREFIX + "pending.minimum";
    private static final String MAX_PENDING_KEY = CONF_PREFIX + "pending.maximum";
    private static final String MIN_WEIGHT_KEY = CONF_PREFIX + "weight.minimum";
    private static final String MAX_EDIT_KEY = CONF_PREFIX + "edit.maximum";
    private static final String SCALING_KEY = CONF_PREFIX + "scaling";

    private static final int DEFAULT_MIN_PENDING = 100;
    private static final int DEFAULT_MAX_PENDING = 1000;
    private static final float DEFAULT_MIN_WEIGHT = 0.0f;
    private static final int DEFAULT_MAX_EDIT = 3;
    private static final String DEFAULT_SCALING = "LINEAR";

    /** Pending-app count at or below which a sub-cluster's weight is not changed. */
    private int minPending;
    /** Pending-app count at or above which a sub-cluster gets {@link #minWeight}. */
    private int maxPending;
    /** Lower bound for an edited weight; validated to lie in [0, 1). */
    private float minWeight;
    /** Maximum number of (most loaded) sub-clusters whose weight may be edited. */
    private int maxEdit;
    /** Curve used to map load between the thresholds onto a weight. */
    private Scaling scaling = Scaling.NONE;

    /**
     * Reads the load-based policy settings from {@code conf} and validates them.
     *
     * @param conf configuration to read from
     * @throws YarnRuntimeException if {@code minPending > maxPending}, if
     *     {@code minWeight} is outside [0, 1), or if {@code maxEdit} is negative
     */
    @Override
    public void setConf(Configuration conf) {
        super.setConf(conf);
        this.minPending = conf.getInt(MIN_PENDING_KEY, DEFAULT_MIN_PENDING);
        this.maxPending = conf.getInt(MAX_PENDING_KEY, DEFAULT_MAX_PENDING);
        this.minWeight = conf.getFloat(MIN_WEIGHT_KEY, DEFAULT_MIN_WEIGHT);
        this.maxEdit = conf.getInt(MAX_EDIT_KEY, DEFAULT_MAX_EDIT);

        String scalingName = conf.get(SCALING_KEY, DEFAULT_SCALING);
        try {
            this.scaling = Scaling.valueOf(scalingName);
        } catch (IllegalArgumentException e) {
            // Best effort: keep the previously effective mode instead of failing.
            LOG.warn("Invalid scaling mode provided: {}, keeping {}", scalingName, this.scaling, e);
        }

        if (this.minPending > this.maxPending) {
            throw new YarnRuntimeException("minPending = " + this.minPending
                + " must be less than or equal to maxPending=" + this.maxPending);
        }
        // Written so that NaN is rejected as well.
        if (!(this.minWeight >= 0.0f && this.minWeight < 1.0f)) {
            throw new YarnRuntimeException("minWeight = " + this.minWeight + " must be within range [0,1)");
        }
        // A negative value would make every getTargetWeights call fail in subList.
        if (this.maxEdit < 0) {
            throw new YarnRuntimeException("maxEdit = " + this.maxEdit + " must be non-negative");
        }
    }

    /**
     * Declares the RM web-service endpoints this policy needs.
     *
     * @return a mutable map from response type to REST path ({@code /metrics})
     */
    @Override
    protected Map<Class<?>, String> registerPaths() {
        Map<Class<?>, String> paths = new HashMap<>();
        paths.put(ClusterMetricsInfo.class, "/metrics");
        return paths;
    }

    /**
     * Creates or refreshes the policy for {@code queueName}.
     *
     * <p>A new {@link WeightedLocalityPolicyManager} is computed when there is no
     * current manager, or when the current one is a weighted-locality manager.
     * Any other manager type is returned unchanged with a warning. The returned
     * value is {@code null} when no cluster metrics are available.
     *
     * @param queueName queue whose policy is updated
     * @param clusterInfo per-sub-cluster responses keyed by response type
     * @param currentManager the queue's existing manager, possibly {@code null}
     * @return the manager to use for the queue
     */
    @Override
    protected FederationPolicyManager updatePolicy(String queueName,
            Map<SubClusterId, Map<Class, Object>> clusterInfo, FederationPolicyManager currentManager) {
        if (currentManager == null) {
            LOG.info("Creating load based weighted policy queue {}.", queueName);
            return this.getWeightedLocalityPolicyManager(queueName, clusterInfo);
        }
        if (currentManager instanceof WeightedLocalityPolicyManager) {
            LOG.info("Updating load based weighted policy queue {}.", queueName);
            return this.getWeightedLocalityPolicyManager(queueName, clusterInfo);
        }
        LOG.warn("Policy for queue {} is of type {}, expected {}.",
            queueName, currentManager.getClass(), WeightedLocalityPolicyManager.class);
        return currentManager;
    }

    /**
     * Builds a weighted-locality manager whose AMRM and router weights are the
     * load-based target weights.
     *
     * @param queue queue name to set on the manager
     * @param subClusterMetricInfos per-sub-cluster responses keyed by response type
     * @return the new manager, or {@code null} if no metrics are available
     */
    protected WeightedLocalityPolicyManager getWeightedLocalityPolicyManager(String queue,
            Map<SubClusterId, Map<Class, Object>> subClusterMetricInfos) {
        Map<SubClusterId, ClusterMetricsInfo> clusterMetrics = this.getSubClustersMetricsInfo(subClusterMetricInfos);
        if (MapUtils.isEmpty(clusterMetrics)) {
            return null;
        }
        WeightedLocalityPolicyManager manager = new WeightedLocalityPolicyManager();
        Map<SubClusterIdInfo, Float> weights = this.getTargetWeights(clusterMetrics);
        manager.setQueue(queue);
        manager.getWeightedPolicyInfo().setAMRMPolicyWeights(weights);
        manager.getWeightedPolicyInfo().setRouterPolicyWeights(weights);
        return manager;
    }

    /**
     * Extracts the {@link ClusterMetricsInfo} of each sub-cluster.
     *
     * <p>Every sub-cluster in the input is kept. Its value is {@code null} when
     * its response map is missing or holds no {@code ClusterMetricsInfo}.
     *
     * @param subClusterMetricsInfo per-sub-cluster responses keyed by response type
     * @return metrics per sub-cluster, or {@code null} if the input is null/empty
     */
    protected Map<SubClusterId, ClusterMetricsInfo> getSubClustersMetricsInfo(
            Map<SubClusterId, Map<Class, Object>> subClusterMetricsInfo) {
        if (MapUtils.isEmpty(subClusterMetricsInfo)) {
            LOG.warn("The metric info of the subCluster is empty.");
            return null;
        }
        Map<SubClusterId, ClusterMetricsInfo> clusterMetrics = new HashMap<>();
        for (Map.Entry<SubClusterId, Map<Class, Object>> entry : subClusterMetricsInfo.entrySet()) {
            Map<Class, Object> responses = entry.getValue();
            Object metrics = responses == null ? null : responses.get(ClusterMetricsInfo.class);
            clusterMetrics.put(entry.getKey(),
                metrics instanceof ClusterMetricsInfo ? (ClusterMetricsInfo) metrics : null);
        }
        return clusterMetrics;
    }

    /**
     * Computes routing weights from the pending-application load of each sub-cluster.
     *
     * <p>Sub-clusters without metrics keep their initial uniform weight and do not
     * consume one of the {@code maxEdit} edit slots.
     *
     * @param clusterMetrics metrics per sub-cluster; values may be {@code null}
     * @return weight per sub-cluster (a mutable map)
     */
    @VisibleForTesting
    protected Map<SubClusterIdInfo, Float> getTargetWeights(Map<SubClusterId, ClusterMetricsInfo> clusterMetrics) {
        Map<SubClusterIdInfo, Float> weights = GPGUtils.createUniformWeights(clusterMetrics.keySet());

        List<SubClusterId> candidates = new ArrayList<>(clusterMetrics.size());
        for (Map.Entry<SubClusterId, ClusterMetricsInfo> entry : clusterMetrics.entrySet()) {
            if (entry.getValue() == null) {
                LOG.warn("No metrics reported for sub cluster {}, leaving its weight unchanged", entry.getKey());
            } else {
                candidates.add(entry.getKey());
            }
        }
        // Most loaded first; only the top maxEdit sub-clusters are re-weighted.
        candidates.sort(new SortByDescendingLoad(clusterMetrics));
        int editCount = Math.max(0, Math.min(this.maxEdit, candidates.size()));

        for (SubClusterId subClusterId : candidates.subList(0, editCount)) {
            LOG.info("Updating weight for sub cluster {}", subClusterId);
            int pending = clusterMetrics.get(subClusterId).getAppsPending();
            if (pending <= this.minPending) {
                LOG.info("Load ({}) is lower than minimum ({}), skipping", pending, this.minPending);
            } else if (pending < this.maxPending) {
                // Shift the domain so scaling works on (0, maxVal) instead of (minPending, maxPending).
                int val = pending - this.minPending;
                int maxVal = this.maxPending - this.minPending;
                float weight = this.getWeightByScaling(maxVal, val);
                // Re-map the scaled value so it never drops below the configured minimum.
                weight *= 1.0f - this.minWeight;
                weight += this.minWeight;
                weights.put(new SubClusterIdInfo(subClusterId), weight);
                LOG.info("Load ({}) is within maximum ({}), setting weights via {} scale to {}",
                    pending, this.maxPending, this.scaling, weight);
            } else {
                weights.put(new SubClusterIdInfo(subClusterId), this.minWeight);
                LOG.info("Load ({}) exceeded maximum ({}), setting weight to minimum: {}",
                    pending, this.maxPending, this.minWeight);
            }
        }
        this.validateWeights(weights);
        return weights;
    }

    /**
     * Maps a load value onto a weight factor using the configured {@link Scaling}.
     *
     * <p>{@link #getTargetWeights} only calls this with
     * {@code 0 < curPendingVal < maxPendingVal}. Other inputs may yield values
     * outside [0, 1] or NaN.
     *
     * @param maxPendingVal upper end of the shifted load range
     * @param curPendingVal current shifted load
     * @return the weight factor; 1.0f for {@link Scaling#NONE}
     */
    protected float getWeightByScaling(int maxPendingVal, int curPendingVal) {
        float weight = 1.0f;
        switch (this.scaling) {
            case NONE:
                break;
            case LINEAR:
                weight = (float) (maxPendingVal - curPendingVal) / (float) maxPendingVal;
                break;
            case QUADRATIC: {
                double maxValQuad = Math.pow(maxPendingVal, 2.0);
                double valQuad = Math.pow(curPendingVal, 2.0);
                weight = (float) (maxValQuad - valQuad) / (float) maxValQuad;
                break;
            }
            case LOG: {
                double maxValLog = Math.log(maxPendingVal);
                double valLog = Math.log(curPendingVal);
                weight = (float) (maxValLog - valLog) / (float) maxValLog;
                break;
            }
            default:
                LOG.warn("No suitable scaling found, Skip.");
        }
        return weight;
    }

    /**
     * Resets all weights to 1.0f when none is strictly positive.
     *
     * <p>An all-zero set would leave no routable sub-cluster.
     *
     * @param weights weights to check, modified in place
     */
    private void validateWeights(Map<SubClusterIdInfo, Float> weights) {
        for (Float w : weights.values()) {
            if (w > 0.0f) {
                return;
            }
        }
        LOG.warn("All {} generated weights were 0.0f. Resetting to 1.0f.", weights.size());
        weights.replaceAll((id, w) -> 1.0f);
    }

    /** Curve used to convert load between the thresholds into a weight factor. */
    public static enum Scaling {
        /** Weight decreases linearly with load. */
        LINEAR,
        /** Weight is {@code 1 - (val/max)^2}. */
        QUADRATIC,
        /** Weight is {@code 1 - log(val)/log(max)}. */
        LOG,
        /** Weight factor is always 1.0f. */
        NONE;
    }

    /**
     * Orders sub-clusters by descending pending-application count.
     *
     * <p>Every compared sub-cluster must have non-null metrics in the map.
     */
    private static final class SortByDescendingLoad implements Comparator<SubClusterId> {
        private final Map<SubClusterId, ClusterMetricsInfo> clusterMetrics;

        private SortByDescendingLoad(Map<SubClusterId, ClusterMetricsInfo> clusterMetrics) {
            this.clusterMetrics = clusterMetrics;
        }

        @Override
        public int compare(SubClusterId a, SubClusterId b) {
            // Operands swapped for descending order; Integer.compare cannot overflow.
            return Integer.compare(this.clusterMetrics.get(b).getAppsPending(),
                this.clusterMetrics.get(a).getAppsPending());
        }
    }
}

