/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.com.nvidia;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.hadoop.classification.VisibleForTesting;
import org.apache.hadoop.thirdparty.com.google.common.collect.ImmutableSet;
import org.apache.hadoop.util.Shell;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePluginScheduler;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Device plugin for NVIDIA GPUs that targets the nvidia-docker v2 runtime.
 *
 * <p>Devices are discovered by running {@code nvidia-smi} and stat'ing the
 * matching {@code /dev/nvidia<N>} nodes. For multi-GPU requests the plugin can
 * schedule by interconnect topology. {@code nvidia-smi topo -m} is parsed into
 * pairwise link weights, every device combination is scored by the sum of
 * its pair weights, and the cheapest ({@code PACK}, the default) or most
 * expensive ({@code SPREAD}) combination that fits the available devices is
 * chosen. Any failure in that path falls back to basic first-N scheduling.
 */
public class NvidiaGPUPluginForRuntimeV2 implements DevicePlugin, DevicePluginScheduler {
    public static final Logger LOG = LoggerFactory.getLogger(NvidiaGPUPluginForRuntimeV2.class);
    public static final String NV_RESOURCE_NAME = "nvidia.com/gpu";
    public static final String TOPOLOGY_POLICY_ENV_KEY = "NVIDIA_TOPO_POLICY";
    public static final String TOPOLOGY_POLICY_PACK = "PACK";
    public static final String TOPOLOGY_POLICY_SPREAD = "SPREAD";

    private static final String ENV_BINARY_PATH = "NVIDIA_SMI_PATH";
    private static final String DEFAULT_BINARY_NAME = "nvidia-smi";
    private static final String DEV_NAME_PREFIX = "nvidia";
    private static final int MAX_EXEC_TIMEOUT_MS = 10000;
    private static final Set<String> DEFAULT_BINARY_SEARCH_DIRS =
        ImmutableSet.of("/usr/bin", "/bin", "/usr/local/nvidia/bin");
    /** Row/column names of GPUs in the topology matrix, e.g. "GPU3". */
    private static final Pattern GPU_NAME_PATTERN = Pattern.compile("GPU(\\d{1,9})");
    /** NVLink cells in the topology matrix, e.g. "NV2" or "NV12". */
    private static final Pattern NVLINK_PATTERN = Pattern.compile("NV([1-9]\\d{0,2})");
    /** NVLink tier indexed by link count; counts above 9 are capped at the strongest tier. */
    private static final DeviceLinkType[] NVLINK_BY_COUNT = {
        null,
        DeviceLinkType.P2PLinkNVLink1,
        DeviceLinkType.P2PLinkNVLink2,
        DeviceLinkType.P2PLinkNVLink3,
        DeviceLinkType.P2PLinkNVLink4,
        DeviceLinkType.P2PLinkNVLink5,
        DeviceLinkType.P2PLinkNVLink6,
        DeviceLinkType.P2PLinkNVLink7,
        DeviceLinkType.P2PLinkNVLink8,
        DeviceLinkType.P2PLinkNVLink9,
    };

    private NvidiaCommandExecutor shellExecutor = new NvidiaCommandExecutor();
    private final Map<String, String> environment = new HashMap<String, String>();
    private String pathOfGpuBinary = null;
    private boolean topoInitialized = false;
    private Set<Device> lastTimeFoundDevices;
    /** Combination size -> (device set, cost) entries sorted by ascending cost. */
    private Map<Integer, List<Map.Entry<Set<Device>, Integer>>> costTable =
        new HashMap<Integer, List<Map.Entry<Set<Device>, Integer>>>();
    /** "rowMinor-colMinor" -> link weight parsed from the topology matrix. */
    private Map<String, Integer> devicePairToWeight = new HashMap<String, Integer>();

    @Override
    public DeviceRegisterRequest getRegisterRequestInfo() throws Exception {
        return DeviceRegisterRequest.Builder.newInstance().setResourceName(NV_RESOURCE_NAME).build();
    }

    /**
     * Discovers GPUs via {@code nvidia-smi --query-gpu=index,pci.bus_id}.
     *
     * <p>The reported index is used as the minor number. Devices whose
     * {@code /dev/nvidia<index>} major number cannot be read are skipped. The
     * result is remembered for building the topology cost table.
     *
     * @return the discovered devices, ids assigned in output order
     * @throws YarnException wrapping the IOException if nvidia-smi cannot be run
     * @throws Exception if no binary is found or a line is not "index,busId"
     */
    @Override
    public Set<Device> getDevices() throws Exception {
        shellExecutor.searchBinary();
        String output;
        try {
            output = shellExecutor.getDeviceInfo();
        } catch (IOException e) {
            LOG.debug("Failed to get output from {}", pathOfGpuBinary);
            throw new YarnException(e);
        }
        Set<Device> found = new TreeSet<Device>();
        int id = 0;
        for (String oneLine : output.trim().split("\n")) {
            if (oneLine.trim().isEmpty()) {
                // Tolerate blank lines (including completely empty output).
                continue;
            }
            String[] tokensEachLine = oneLine.split(",");
            if (tokensEachLine.length != 2) {
                throw new Exception("Cannot parse the output to get device info. Unexpected format in it:" + oneLine);
            }
            String minorNumber = tokensEachLine[0].trim();
            String busId = tokensEachLine[1].trim();
            String majorNumber = getMajorNumber(DEV_NAME_PREFIX + minorNumber);
            if (majorNumber == null) {
                // The device node could not be stat'ed or parsed; skip it.
                continue;
            }
            found.add(Device.Builder.newInstance()
                .setId(id++)
                .setMajorNumber(Integer.parseInt(majorNumber))
                .setMinorNumber(Integer.parseInt(minorNumber))
                .setBusID(busId)
                .setDevPath("/dev/nvidia" + minorNumber)
                .setHealthy(true)
                .build());
        }
        lastTimeFoundDevices = found;
        return found;
    }

    /**
     * Builds the runtime spec for an allocation.
     *
     * <p>For Docker, this selects the "nvidia" container runtime and exposes
     * the allocated minor numbers as a comma-separated
     * {@code NVIDIA_VISIBLE_DEVICES}. An empty allocation yields an empty
     * value instead of failing. Other runtimes get {@code null}.
     */
    @Override
    public DeviceRuntimeSpec onDevicesAllocated(Set<Device> allocatedDevices, YarnRuntimeType yarnRuntime) throws Exception {
        LOG.debug("Generating runtime spec for allocated devices: {}, {}", allocatedDevices, yarnRuntime.getName());
        if (yarnRuntime != YarnRuntimeType.RUNTIME_DOCKER) {
            return null;
        }
        String minorNumbers = allocatedDevices.stream()
            .map(device -> String.valueOf(device.getMinorNumber()))
            .collect(Collectors.joining(","));
        LOG.info("Nvidia Docker v2 assigned GPU: " + minorNumbers);
        return DeviceRuntimeSpec.Builder.newInstance()
            .addEnv("NVIDIA_VISIBLE_DEVICES", minorNumbers)
            .setContainerRuntime(DEV_NAME_PREFIX)
            .build();
    }

    @Override
    public void onDevicesReleased(Set<Device> releasedDevices) throws Exception {
        // Nothing to clean up for this plugin.
    }

    /**
     * Reads the major number of {@code /dev/<devName>}.
     *
     * <p>The value comes from {@code stat -c %t:%T}, whose major part is hex.
     *
     * @return the major number in decimal, or {@code null} if it cannot be read or parsed
     */
    private String getMajorNumber(String devName) {
        try {
            LOG.debug("Get major numbers from /dev/{}", devName);
            String output = shellExecutor.getMajorMinorInfo(devName);
            LOG.debug("stat output:{}", output);
            String majorHex = output.trim().split(":")[0];
            return Integer.toString(Integer.parseInt(majorHex, 16));
        } catch (IOException e) {
            LOG.warn("Failed to get major number from reading /dev/" + devName);
        } catch (NumberFormatException e) {
            LOG.error("Failed to parse device major number from stat output");
        }
        return null;
    }

    /**
     * Chooses {@code count} devices out of {@code availableDevices}.
     *
     * <p>Small or trivial requests use basic scheduling. Otherwise topology
     * aware scheduling is attempted, and any failure (IO or runtime) falls back
     * to basic scheduling, so the caller always gets an allocation.
     */
    @Override
    public Set<Device> allocateDevices(Set<Device> availableDevices, int count, Map<String, String> envs) {
        Set<Device> allocation = new TreeSet<Device>();
        if (availableDevices.size() < 3 || count == 1 || availableDevices.size() == count) {
            basicSchedule(allocation, count, availableDevices);
            return allocation;
        }
        try {
            if (!topoInitialized) {
                initCostTable();
            }
            topologyAwareSchedule(allocation, count, envs, availableDevices, costTable);
            if (allocation.size() == count) {
                return allocation;
            }
            LOG.error("Failed to do topology scheduling. Skip to use basic scheduling");
        } catch (IOException e) {
            LOG.error("Error in getting GPU topology info. Skip topology aware scheduling", e);
        } catch (RuntimeException e) {
            // e.g. unparsable topology output or a missing pair weight; degrade gracefully.
            LOG.error("Unexpected error in topology aware scheduling. Skip it", e);
        }
        allocation.clear();
        basicSchedule(allocation, count, availableDevices);
        return allocation;
    }

    /**
     * Parses the topology and builds the combination cost table.
     *
     * <p>If the device list is not yet known and cannot be fetched, this returns
     * without marking the topology as initialized, so a later call retries.
     */
    @VisibleForTesting
    public void initCostTable() throws IOException {
        String topo = shellExecutor.getTopologyInfo();
        parseTopo(topo, devicePairToWeight);
        if (lastTimeFoundDevices == null) {
            try {
                getDevices();
            } catch (Exception e) {
                LOG.error("Failed to get devices!", e);
                return;
            }
        }
        buildCostTable(costTable, lastTimeFoundDevices);
        loggingCostTable(costTable);
        topoInitialized = true;
    }

    /** Dumps the cost table at debug level. */
    private void loggingCostTable(Map<Integer, List<Map.Entry<Set<Device>, Integer>>> cTable) {
        if (!LOG.isDebugEnabled()) {
            return;
        }
        StringBuilder sb = new StringBuilder("The costTable is:");
        sb.append("\n{");
        for (Map.Entry<Integer, List<Map.Entry<Set<Device>, Integer>>> entry : cTable.entrySet()) {
            sb.append("\n\t").append(entry.getKey()).append(" => [");
            for (Map.Entry<Set<Device>, Integer> e : entry.getValue()) {
                sb.append("\n\t\t").append(e.toString()).append(",\n");
            }
            sb.append("\t\t]\n");
        }
        sb.append("}\n");
        LOG.debug(sb.toString());
    }

    private void buildCostTable(Map<Integer, List<Map.Entry<Set<Device>, Integer>>> cTable, Set<Device> ltfDevices) {
        Device[] deviceList = ltfDevices.toArray(new Device[0]);
        generateAllDeviceCombination(cTable, deviceList, deviceList.length);
    }

    /**
     * Fills {@code cTable} with every combination of size 2 .. n-1.
     *
     * <p>Each combination list is sorted by ascending cost. The sort is stable,
     * so ties keep their enumeration order.
     */
    private void generateAllDeviceCombination(Map<Integer, List<Map.Entry<Set<Device>, Integer>>> cTable, Device[] allDevices, int n) {
        for (int size = 2; size < n; ++size) {
            Map<Set<Device>, Integer> combinationToCost = new HashMap<Set<Device>, Integer>();
            buildCombination(combinationToCost, allDevices, n, size);
            List<Map.Entry<Set<Device>, Integer>> listSortedByCost =
                new ArrayList<Map.Entry<Set<Device>, Integer>>(combinationToCost.entrySet());
            listSortedByCost.sort(Map.Entry.comparingByValue());
            cTable.put(size, listSortedByCost);
        }
    }

    private void buildCombination(Map<Set<Device>, Integer> combinationToCost, Device[] allDevices, int n, int r) {
        Device[] subDeviceList = new Device[r];
        combinationRecursive(combinationToCost, allDevices, subDeviceList, 0, n - 1, 0, r);
    }

    /**
     * Enumerates all r-combinations of {@code allDevices[start..end]}.
     *
     * <p>Each combination completes {@code subDeviceList[index..r-1]} and is
     * recorded in {@code cTc} with its cost.
     */
    void combinationRecursive(Map<Set<Device>, Integer> cTc, Device[] allDevices, Device[] subDeviceList, int start, int end, int index, int r) {
        if (index == r) {
            TreeSet<Device> oneSet = new TreeSet<Device>(Arrays.asList(subDeviceList));
            cTc.put(oneSet, computeCostOfDevices(subDeviceList));
            return;
        }
        for (int i = start; i <= end; ++i) {
            subDeviceList[index] = allDevices[i];
            combinationRecursive(cTc, allDevices, subDeviceList, i + 1, end, index + 1, r);
        }
    }

    /**
     * Sums the link weights of every ordered pair (i &lt; j) of {@code devices}.
     *
     * @throws IllegalStateException if a pair has no weight in the parsed topology
     */
    @VisibleForTesting
    public int computeCostOfDevices(Device[] devices) {
        int cost = 0;
        for (int i = 0; i < devices.length; ++i) {
            int left = devices[i].getMinorNumber();
            for (int j = i + 1; j < devices.length; ++j) {
                int right = devices[j].getMinorNumber();
                Integer weight = devicePairToWeight.get(left + "-" + right);
                if (weight == null) {
                    throw new IllegalStateException("No link weight found between GPU " + left + " and GPU " + right);
                }
                cost += weight;
            }
        }
        return cost;
    }

    /**
     * Picks the first combination of size {@code count} that fits
     * {@code availableDevices} and adds it to {@code allocation}.
     *
     * <p>Combinations are walked in ascending cost for PACK (the default) and in
     * descending cost for SPREAD. If there is no table, no entry for
     * {@code count}, or no fitting combination, {@code allocation} is left
     * untouched.
     */
    @VisibleForTesting
    public void topologyAwareSchedule(Set<Device> allocation, int count, Map<String, String> envs, Set<Device> availableDevices, Map<Integer, List<Map.Entry<Set<Device>, Integer>>> cTable) {
        String policy = envs == null ? null : envs.get(TOPOLOGY_POLICY_ENV_KEY);
        if (policy == null) {
            policy = TOPOLOGY_POLICY_PACK;
        }
        if (cTable == null) {
            LOG.error("No cost table initialized!");
            return;
        }
        List<Map.Entry<Set<Device>, Integer>> combinationsToCost = cTable.get(count);
        if (combinationsToCost == null) {
            LOG.error("No device combination of size {} in the cost table", count);
            return;
        }
        boolean spread = policy.equalsIgnoreCase(TOPOLOGY_POLICY_SPREAD);
        // Walk backwards from the end for SPREAD; works for any List implementation.
        ListIterator<Map.Entry<Set<Device>, Integer>> it =
            combinationsToCost.listIterator(spread ? combinationsToCost.size() : 0);
        while (spread ? it.hasPrevious() : it.hasNext()) {
            Map.Entry<Set<Device>, Integer> element = spread ? it.previous() : it.next();
            if (availableDevices.containsAll(element.getKey())) {
                allocation.addAll(element.getKey());
                LOG.info("Topology scheduler allocated: " + allocation);
                return;
            }
        }
        LOG.error("Unknown error happened in topology scheduler");
    }

    /**
     * Adds the first {@code count} devices of {@code availableDevices}, in
     * iteration order, to {@code allocation}.
     *
     * <p>All devices are added when {@code count} equals the set size. A
     * non-positive {@code count} adds nothing.
     */
    @VisibleForTesting
    public void basicSchedule(Set<Device> allocation, int count, Set<Device> availableDevices) {
        if (count == availableDevices.size()) {
            allocation.addAll(availableDevices);
            return;
        }
        int added = 0;
        for (Device d : availableDevices) {
            if (added >= count) {
                break;
            }
            allocation.add(d);
            added++;
        }
    }

    /**
     * Parses {@code nvidia-smi topo -m} output into "row-col" link weights.
     *
     * <p>Only rows named {@code GPU<n>} are parsed, so NIC rows such as
     * {@code mlx5_0} are skipped. When the header line (the one containing
     * "Affinity") has been seen, only its leading GPU columns are parsed.
     * Otherwise every cell is considered, and column {@code i} maps to minor
     * number {@code i-1}. Parsing stops at the "Legend" section. Unknown cell
     * values and "X" are ignored.
     */
    public void parseTopo(String topo, Map<String, Integer> deviceLinkToWeight) {
        // Number of leading GPU columns announced by the header; -1 = unknown.
        int gpuColumns = -1;
        for (String rawLine : topo.split("\n")) {
            String oneLine = rawLine.trim();
            if (oneLine.isEmpty()) {
                continue;
            }
            if (oneLine.startsWith("Legend")) {
                break;
            }
            String[] tokens = oneLine.split("\\s+");
            if (oneLine.contains("Affinity")) {
                gpuColumns = countLeadingGpuNames(tokens);
                continue;
            }
            Matcher rowName = GPU_NAME_PATTERN.matcher(tokens[0]);
            if (!rowName.matches()) {
                continue;
            }
            int rowMinor = Integer.parseInt(rowName.group(1));
            int lastColumn = gpuColumns < 0 ? tokens.length - 1 : Math.min(gpuColumns, tokens.length - 1);
            for (int i = 1; i <= lastColumn; ++i) {
                DeviceLinkType linkType = toLinkType(tokens[i]);
                if (linkType != null) {
                    populateGraphEdgeWeight(linkType, rowMinor, i - 1, deviceLinkToWeight);
                }
            }
        }
    }

    /** Counts header tokens named GPU&lt;n&gt; at the start; returns -1 if there are none. */
    private static int countLeadingGpuNames(String[] headerTokens) {
        int n = 0;
        while (n < headerTokens.length && GPU_NAME_PATTERN.matcher(headerTokens[n]).matches()) {
            n++;
        }
        return n == 0 ? -1 : n;
    }

    /**
     * Maps a topology matrix cell to a link type.
     *
     * @return the link type, or {@code null} for "X" and unrecognized values
     */
    private static DeviceLinkType toLinkType(String cell) {
        switch (cell) {
            case "SOC":
            case "SYS":
                return DeviceLinkType.P2PLinkCrossCPUSocket;
            case "PHB":
            case "NODE":
                return DeviceLinkType.P2PLinkSameCPUSocket;
            case "PXB":
                return DeviceLinkType.P2PLinkMultiSwitch;
            case "PIX":
                return DeviceLinkType.P2PLinkSingleSwitch;
            default:
                Matcher nvlink = NVLINK_PATTERN.matcher(cell);
                if (!nvlink.matches()) {
                    return null;
                }
                int links = Integer.parseInt(nvlink.group(1));
                // More bonded links means a lower weight; cap at the strongest known tier.
                return NVLINK_BY_COUNT[Math.min(links, NVLINK_BY_COUNT.length - 1)];
        }
    }

    private void populateGraphEdgeWeight(DeviceLinkType linkType, int leftVertex, int rightVertex, Map<String, Integer> deviceLinkToWeight) {
        deviceLinkToWeight.put(leftVertex + "-" + rightVertex, linkType.getWeight());
    }

    @VisibleForTesting
    public void setPathOfGpuBinary(String pOfGpuBinary) {
        this.pathOfGpuBinary = pOfGpuBinary;
    }

    @VisibleForTesting
    public void setShellExecutor(NvidiaCommandExecutor shellExecutor) {
        this.shellExecutor = shellExecutor;
    }

    @VisibleForTesting
    public boolean isTopoInitialized() {
        return topoInitialized;
    }

    @VisibleForTesting
    public Map<Integer, List<Map.Entry<Set<Device>, Integer>>> getCostTable() {
        return costTable;
    }

    @VisibleForTesting
    public Map<String, Integer> getDevicePairToWeight() {
        return devicePairToWeight;
    }

    /** Runs the shell commands the plugin depends on; replaceable in tests. */
    public class NvidiaCommandExecutor {
        public String getDeviceInfo() throws IOException {
            return Shell.execCommand(environment,
                new String[]{pathOfGpuBinary, "--query-gpu=index,pci.bus_id", "--format=csv,noheader"},
                MAX_EXEC_TIMEOUT_MS);
        }

        public String getMajorMinorInfo(String devName) throws IOException {
            Shell.ShellCommandExecutor shexec =
                new Shell.ShellCommandExecutor(new String[]{"stat", "-c", "%t:%T", "/dev/" + devName});
            shexec.execute();
            return shexec.getOutput();
        }

        public String getTopologyInfo() throws IOException {
            return Shell.execCommand(environment,
                new String[]{pathOfGpuBinary, "topo", "-m"},
                MAX_EXEC_TIMEOUT_MS);
        }

        /**
         * Resolves the nvidia-smi binary.
         *
         * <p>An explicitly set path wins. Otherwise the NVIDIA_SMI_PATH
         * environment variable is used if it names an existing file. Failing
         * that, the default directories are searched.
         *
         * @throws Exception if no binary is found
         */
        public void searchBinary() throws Exception {
            if (pathOfGpuBinary != null) {
                LOG.info("Skip searching, the nvidia gpu binary is already set: " + pathOfGpuBinary);
                return;
            }
            String envBinaryPath = System.getenv(ENV_BINARY_PATH);
            if (envBinaryPath != null && new File(envBinaryPath).exists()) {
                pathOfGpuBinary = envBinaryPath;
                LOG.info("Use nvidia gpu binary: " + pathOfGpuBinary);
                return;
            }
            LOG.info("Search binary..");
            for (String dir : DEFAULT_BINARY_SEARCH_DIRS) {
                File binaryFile = new File(dir, DEFAULT_BINARY_NAME);
                if (binaryFile.exists()) {
                    pathOfGpuBinary = binaryFile.getAbsolutePath();
                    LOG.info("Found binary:" + pathOfGpuBinary);
                    return;
                }
            }
            LOG.error("No binary found from env variable: NVIDIA_SMI_PATH or path " + DEFAULT_BINARY_SEARCH_DIRS.toString());
            throw new Exception("No binary found for " + NvidiaGPUPluginForRuntimeV2.class);
        }
    }

    /** Interconnect classes between two GPUs; a lower weight means a faster link. */
    public static enum DeviceLinkType {
        P2PLinkNVLink9(10),
        P2PLinkNVLink8(20),
        P2PLinkNVLink7(30),
        P2PLinkNVLink6(40),
        P2PLinkNVLink5(50),
        P2PLinkNVLink4(60),
        P2PLinkNVLink3(70),
        P2PLinkNVLink2(80),
        P2PLinkNVLink1(90),
        P2PLinkSameCPUSocket(200),
        P2PLinkCrossCPUSocket(300),
        P2PLinkSingleSwitch(600),
        P2PLinkMultiSwitch(1200);

        private final int weight;

        public int getWeight() {
            return weight;
        }

        private DeviceLinkType(int w) {
            this.weight = w;
        }
    }
}

