package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.com.nvidia;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import org.apache.hadoop.classification.VisibleForTesting;
import org.apache.hadoop.shaded.org.glassfish.jersey.logging.LoggingFeature;
import org.apache.hadoop.thirdparty.com.google.common.collect.ImmutableSet;
import org.apache.hadoop.util.Shell;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePluginScheduler;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
import org.apache.hadoop.yarn.server.resourcemanager.scheduler.capacity.placement.converter.LegacyMappingRuleToJson;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/com/nvidia/NvidiaGPUPluginForRuntimeV2.class */
public class NvidiaGPUPluginForRuntimeV2 implements DevicePlugin, DevicePluginScheduler {
    /** Resource name set on the register request built by this plugin. */
    public static final String NV_RESOURCE_NAME = "nvidia.com/gpu";
    /** Environment variable consulted first when locating the GPU binary. */
    private static final String ENV_BINARY_PATH = "NVIDIA_SMI_PATH";
    /** File name looked up inside each default search directory. */
    private static final String DEFAULT_BINARY_NAME = "nvidia-smi";
    /** Prefix of the per-GPU device node name under /dev. */
    private static final String DEV_NAME_PREFIX = "nvidia";
    /** Timeout value in milliseconds; presumably for external command execution (name-derived). */
    private static final int MAX_EXEC_TIMEOUT_MS = 10000;
    /** Result of the most recent successful {@code getDevices()} call; null until then. */
    private Set<Device> lastTimeFoundDevices;
    /** Key read from the allocation environment to select the topology policy. */
    public static final String TOPOLOGY_POLICY_ENV_KEY = "NVIDIA_TOPO_POLICY";
    /** Policy value used when no policy is supplied. */
    public static final String TOPOLOGY_POLICY_PACK = "PACK";
    /** Policy value that makes the scheduler walk the cost list from the most expensive end. */
    public static final String TOPOLOGY_POLICY_SPREAD = "SPREAD";
    public static final Logger LOG = LoggerFactory.getLogger(NvidiaGPUPluginForRuntimeV2.class);
    /** Directories probed for the binary when the environment variable gives no usable path. */
    private static final Set<String> DEFAULT_BINARY_SEARCH_DIRS = ImmutableSet.of("/usr/bin", "/bin", "/usr/local/nvidia/bin");
    /** Executor for external commands; replaceable through {@code setShellExecutor}. */
    private NvidiaCommandExecutor shellExecutor = new NvidiaCommandExecutor();
    /** Environment handed to the commands launched through {@code Shell.execCommand}. */
    private Map<String, String> environment = new HashMap<>();
    /** Path of the GPU binary; null until found or set explicitly. */
    private String pathOfGpuBinary = null;
    /** Set to true once {@code initCostTable()} has run to completion. */
    private boolean topoInitialized = false;
    /** Combination size mapped to its device combinations, each paired with a cost. */
    private Map<Integer, List<Map.Entry<Set<Device>, Integer>>> costTable = new HashMap<>();
    /** Weight per device pair, keyed as {@code "<a>-<b>"}. */
    private Map<String, Integer> devicePairToWeight = new HashMap<>();

    /* loaded from: input_file:org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/com/nvidia/NvidiaGPUPluginForRuntimeV2$DeviceLinkType.class */
    /**
     * Kinds of connection between two devices, each carrying an integer weight.
     * The weights are what {@code parseTopo} records per device pair.
     */
    public enum DeviceLinkType {
        P2PLinkNVLink9(10),
        P2PLinkNVLink8(20),
        P2PLinkNVLink7(30),
        P2PLinkNVLink6(40),
        P2PLinkNVLink5(50),
        P2PLinkNVLink4(60),
        P2PLinkNVLink3(70),
        P2PLinkNVLink2(80),
        P2PLinkNVLink1(90),
        P2PLinkSameCPUSocket(200),
        P2PLinkCrossCPUSocket(300),
        P2PLinkSingleSwitch(600),
        P2PLinkMultiSwitch(1200);

        /** Weight assigned to this link type; fixed at construction. */
        private final int weight;

        DeviceLinkType(int linkWeight) {
            weight = linkWeight;
        }

        /**
         * @return the weight associated with this link type
         */
        public int getWeight() {
            return weight;
        }
    }

    /* loaded from: input_file:org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/com/nvidia/NvidiaGPUPluginForRuntimeV2$NvidiaCommandExecutor.class */
    public class NvidiaCommandExecutor {
        public NvidiaCommandExecutor() {
        }

        public String getDeviceInfo() throws IOException {
            return Shell.execCommand(NvidiaGPUPluginForRuntimeV2.this.environment, new String[]{NvidiaGPUPluginForRuntimeV2.this.pathOfGpuBinary, "--query-gpu=index,pci.bus_id", "--format=csv,noheader"}, 10000L);
        }

        public String getMajorMinorInfo(String str) throws IOException {
            Shell.ShellCommandExecutor shellCommandExecutor = new Shell.ShellCommandExecutor(new String[]{"stat", "-c", "%t:%T", "/dev/" + str});
            shellCommandExecutor.execute();
            return shellCommandExecutor.getOutput();
        }

        public String getTopologyInfo() throws IOException {
            return Shell.execCommand(NvidiaGPUPluginForRuntimeV2.this.environment, new String[]{NvidiaGPUPluginForRuntimeV2.this.pathOfGpuBinary, "topo", "-m"}, 10000L);
        }

        public void searchBinary() throws Exception {
            if (NvidiaGPUPluginForRuntimeV2.this.pathOfGpuBinary != null) {
                NvidiaGPUPluginForRuntimeV2.LOG.info("Skip searching, the nvidia gpu binary is already set: " + NvidiaGPUPluginForRuntimeV2.this.pathOfGpuBinary);
                return;
            }
            String str = System.getenv(NvidiaGPUPluginForRuntimeV2.ENV_BINARY_PATH);
            if (null != str && new File(str).exists()) {
                NvidiaGPUPluginForRuntimeV2.this.pathOfGpuBinary = str;
                NvidiaGPUPluginForRuntimeV2.LOG.info("Use nvidia gpu binary: " + NvidiaGPUPluginForRuntimeV2.this.pathOfGpuBinary);
                return;
            }
            NvidiaGPUPluginForRuntimeV2.LOG.info("Search binary..");
            boolean z = false;
            Iterator<String> it = NvidiaGPUPluginForRuntimeV2.DEFAULT_BINARY_SEARCH_DIRS.iterator();
            while (true) {
                if (!it.hasNext()) {
                    break;
                }
                File file = new File(it.next(), NvidiaGPUPluginForRuntimeV2.DEFAULT_BINARY_NAME);
                if (file.exists()) {
                    z = true;
                    NvidiaGPUPluginForRuntimeV2.this.pathOfGpuBinary = file.getAbsolutePath();
                    NvidiaGPUPluginForRuntimeV2.LOG.info("Found binary:" + NvidiaGPUPluginForRuntimeV2.this.pathOfGpuBinary);
                    break;
                }
            }
            if (z) {
                return;
            }
            NvidiaGPUPluginForRuntimeV2.LOG.error("No binary found from env variable: NVIDIA_SMI_PATH or path " + NvidiaGPUPluginForRuntimeV2.DEFAULT_BINARY_SEARCH_DIRS.toString());
            throw new Exception("No binary found for " + NvidiaGPUPluginForRuntimeV2.class);
        }
    }

    /**
     * Builds the register request for this plugin, carrying only the resource name.
     *
     * @return a request whose resource name is {@link #NV_RESOURCE_NAME}
     */
    @Override
    public DeviceRegisterRequest getRegisterRequestInfo() throws Exception {
        DeviceRegisterRequest.Builder requestBuilder =
                DeviceRegisterRequest.Builder.newInstance().setResourceName(NV_RESOURCE_NAME);
        return requestBuilder.build();
    }

    /**
     * Discovers the GPUs by parsing the {@code index, pci.bus_id} CSV lines printed by
     * the GPU binary, and remembers the result for later cost-table construction.
     *
     * <p>A device whose major number cannot be determined is left out of the result.
     * Ids are assigned sequentially to the devices that are kept.
     *
     * @return the discovered devices
     * @throws YarnException if running the GPU binary fails
     * @throws Exception if the binary cannot be found or an output line does not have
     *         exactly two comma-separated fields
     */
    @Override
    public Set<Device> getDevices() throws Exception {
        shellExecutor.searchBinary();
        TreeSet<Device> devices = new TreeSet<>();
        try {
            int nextId = 0;
            // The output is split on the newline character directly; this parsing must
            // not depend on a constant owned by an unrelated library.
            for (String line : shellExecutor.getDeviceInfo().trim().split("\n")) {
                String[] fields = line.split(",");
                if (fields.length != 2) {
                    throw new Exception("Cannot parse the output to get device info. Unexpected format in it:" + line);
                }
                String index = fields[0].trim();
                String busId = fields[1].trim();
                String majorNumber = getMajorNumber(DEV_NAME_PREFIX + index);
                if (majorNumber != null) {
                    devices.add(Device.Builder.newInstance()
                            .setId(nextId)
                            .setMajorNumber(Integer.parseInt(majorNumber))
                            .setMinorNumber(Integer.parseInt(index))
                            .setBusID(busId)
                            .setDevPath("/dev/" + DEV_NAME_PREFIX + index)
                            .setHealthy(true)
                            .build());
                    nextId++;
                }
            }
            lastTimeFoundDevices = devices;
            return devices;
        } catch (IOException e) {
            // Keep the cause in the log as well as in the rethrown exception.
            LOG.debug("Failed to get output from {}", pathOfGpuBinary, e);
            throw new YarnException(e);
        }
    }

    /**
     * Produces the runtime spec for a set of allocated devices. Only the Docker runtime
     * gets a spec: the devices' minor numbers are joined with commas into the
     * {@code NVIDIA_VISIBLE_DEVICES} environment variable.
     *
     * <p>An empty device set yields an empty variable value rather than failing while
     * trimming a separator that was never written.
     *
     * @param set the allocated devices
     * @param yarnRuntimeType the container runtime type
     * @return the runtime spec, or {@code null} when the runtime is not Docker
     */
    @Override
    public DeviceRuntimeSpec onDevicesAllocated(Set<Device> set, YarnRuntimeType yarnRuntimeType) throws Exception {
        LOG.debug("Generating runtime spec for allocated devices: {}, {}", set, yarnRuntimeType.getName());
        if (yarnRuntimeType != YarnRuntimeType.RUNTIME_DOCKER) {
            return null;
        }
        // Write the separator before every element except the first, so there is no
        // trailing comma to strip afterwards.
        StringBuilder joined = new StringBuilder();
        for (Device device : set) {
            if (joined.length() > 0) {
                joined.append(',');
            }
            joined.append(device.getMinorNumber());
        }
        String visibleDevices = joined.toString();
        LOG.info("Nvidia Docker v2 assigned GPU: " + visibleDevices);
        return DeviceRuntimeSpec.Builder.newInstance()
                .addEnv("NVIDIA_VISIBLE_DEVICES", visibleDevices)
                .setContainerRuntime("nvidia")
                .build();
    }

    /**
     * Callback for released devices. This plugin does nothing on release.
     *
     * @param set the released devices (unused)
     */
    @Override // org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin
    public void onDevicesReleased(Set<Device> set) throws Exception {
        // Intentionally empty.
    }

    /**
     * Determines the major number of a device node from the {@code stat -c %t:%T}
     * output, whose first colon-separated field is the major number in hexadecimal.
     *
     * @param devName device node name under /dev
     * @return the major number as a decimal string, or {@code null} if the stat
     *         command fails or its output cannot be parsed
     */
    private String getMajorNumber(String devName) {
        try {
            LOG.debug("Get major numbers from /dev/{}", devName);
            String statOutput = shellExecutor.getMajorMinorInfo(devName);
            // The stat format string is "%t:%T", so the separator is a plain colon.
            String[] majorMinor = statOutput.trim().split(":");
            LOG.debug("stat output:{}", statOutput);
            return Integer.toString(Integer.parseInt(majorMinor[0], 16));
        } catch (IOException e) {
            LOG.warn("Failed to get major number from reading /dev/" + devName, e);
        } catch (NumberFormatException e) {
            LOG.error("Failed to parse device major number from stat output", e);
        }
        return null;
    }

    /**
     * Picks {@code i} devices out of the available ones.
     *
     * <p>Basic scheduling is used directly when fewer than three devices are available,
     * when a single device is requested, or when every available device is requested.
     * Otherwise topology-aware scheduling is attempted, and basic scheduling is the
     * fallback whenever that attempt does not produce exactly {@code i} devices.
     *
     * @param set the devices currently available
     * @param i the number of devices requested
     * @param map environment of the request; may carry the topology policy
     * @return the chosen devices
     */
    @Override
    public Set<Device> allocateDevices(Set<Device> set, int i, Map<String, String> map) {
        Set<Device> allocation = new TreeSet<>();
        if (set.size() < 3 || i == 1 || set.size() == i) {
            basicSchedule(allocation, i, set);
            return allocation;
        }
        try {
            if (!topoInitialized) {
                initCostTable();
            }
            topologyAwareSchedule(allocation, i, map, set, costTable);
        } catch (IOException e) {
            LOG.error("Error in getting GPU topology info. Skip topology aware scheduling", e);
        } catch (RuntimeException e) {
            // Topology scheduling is best effort: unparsable topology output or a missing
            // cost-table entry must not abort the allocation while a fallback exists.
            LOG.error("Unexpected failure in topology aware scheduling. Skip topology aware scheduling", e);
        }
        if (allocation.size() == i) {
            return allocation;
        }
        LOG.error("Failed to do topology scheduling. Skip to use basic scheduling");
        // Start the fallback from a clean slate in case a partial result was left behind.
        allocation.clear();
        basicSchedule(allocation, i, set);
        return allocation;
    }

    /**
     * Initializes the topology data: parses the topology output into the pair-weight
     * map, makes sure a device list exists, builds the cost table from it, and marks
     * the topology as initialized.
     *
     * <p>If the device list has to be fetched and that fails, the failure is logged
     * and the method returns without marking the topology as initialized.
     *
     * @throws IOException if obtaining the topology output fails
     */
    @VisibleForTesting
    public void initCostTable() throws IOException {
        String topologyOutput = shellExecutor.getTopologyInfo();
        parseTopo(topologyOutput, devicePairToWeight);

        boolean devicesUnknown = lastTimeFoundDevices == null;
        if (devicesUnknown) {
            try {
                getDevices();
            } catch (Exception e) {
                LOG.error("Failed to get devices!", e);
                return;
            }
        }

        buildCostTable(costTable, lastTimeFoundDevices);
        loggingCostTable(costTable);
        topoInitialized = true;
    }

    /**
     * Writes a multi-line dump of the cost table to the debug log. Nothing is built
     * or logged unless debug logging is enabled.
     *
     * @param table the cost table to dump
     */
    private void loggingCostTable(Map<Integer, List<Map.Entry<Set<Device>, Integer>>> table) {
        if (!LOG.isDebugEnabled()) {
            return;
        }
        StringBuilder dump = new StringBuilder("The costTable is:");
        dump.append("\n{");
        for (Map.Entry<Integer, List<Map.Entry<Set<Device>, Integer>>> bySize : table.entrySet()) {
            dump.append("\n\t");
            dump.append(bySize.getKey());
            dump.append(" => [");
            for (Map.Entry<Set<Device>, Integer> combination : bySize.getValue()) {
                dump.append("\n\t\t");
                dump.append(combination.toString());
                dump.append(",\n");
            }
            dump.append("\t\t]\n");
        }
        dump.append("}\n");
        LOG.debug(dump.toString());
    }

    /**
     * Fills the cost table with the device combinations of the given device set.
     *
     * @param table the cost table to fill
     * @param devices the devices to combine
     */
    private void buildCostTable(Map<Integer, List<Map.Entry<Set<Device>, Integer>>> table, Set<Device> devices) {
        Device[] deviceArray = devices.toArray(new Device[0]);
        generateAllDeviceCombination(table, deviceArray, deviceArray.length);
    }

    /**
     * For every combination size from 2 up to (but excluding) {@code deviceCount},
     * computes all device combinations with their costs and stores them in the table
     * as a list sorted by ascending cost.
     *
     * @param table the cost table to fill, keyed by combination size
     * @param devices the devices to combine
     * @param deviceCount number of devices taken from {@code devices}
     */
    private void generateAllDeviceCombination(Map<Integer, List<Map.Entry<Set<Device>, Integer>>> table, Device[] devices, int deviceCount) {
        int combinationSize = 2;
        while (combinationSize < deviceCount) {
            Map<Set<Device>, Integer> combinationToCost = new HashMap<>();
            buildCombination(combinationToCost, devices, deviceCount, combinationSize);
            // Kept as a LinkedList: topologyAwareSchedule walks this list backwards via
            // LinkedList.descendingIterator for the SPREAD policy.
            LinkedList<Map.Entry<Set<Device>, Integer>> byAscendingCost =
                    new LinkedList<>(combinationToCost.entrySet());
            Collections.sort(byAscendingCost,
                    (left, right) -> left.getValue().compareTo(right.getValue()));
            table.put(combinationSize, byAscendingCost);
            combinationSize++;
        }
    }

    /**
     * Collects every combination of {@code combinationSize} devices, with its cost,
     * into {@code combinationToCost}.
     *
     * @param combinationToCost receives each combination and its cost
     * @param devices the devices to combine
     * @param deviceCount number of devices taken from {@code devices}
     * @param combinationSize number of devices per combination
     */
    private void buildCombination(Map<Set<Device>, Integer> combinationToCost, Device[] devices, int deviceCount, int combinationSize) {
        Device[] scratch = new Device[combinationSize];
        int lastIndex = deviceCount - 1;
        combinationRecursive(combinationToCost, devices, scratch, 0, lastIndex, 0, combinationSize);
    }

    /**
     * Recursively enumerates device combinations. Once {@code i3} slots of the scratch
     * array are filled up to the target size {@code i4}, the scratch content is stored
     * as a sorted set together with its cost.
     *
     * @param map receives each completed combination and its cost
     * @param deviceArr the devices to choose from
     * @param deviceArr2 scratch array holding the combination under construction
     * @param i first index of {@code deviceArr} still eligible
     * @param i2 last index of {@code deviceArr} eligible (inclusive)
     * @param i3 next slot of {@code deviceArr2} to fill
     * @param i4 target combination size
     */
    void combinationRecursive(Map<Set<Device>, Integer> map, Device[] deviceArr, Device[] deviceArr2, int i, int i2, int i3, int i4) {
        boolean combinationComplete = i3 == i4;
        if (combinationComplete) {
            Set<Device> combination = new TreeSet<>(Arrays.asList(deviceArr2));
            map.put(combination, computeCostOfDevices(deviceArr2));
            return;
        }
        int candidate = i;
        while (candidate <= i2) {
            deviceArr2[i3] = deviceArr[candidate];
            combinationRecursive(map, deviceArr, deviceArr2, candidate + 1, i2, i3 + 1, i4);
            candidate++;
        }
    }

    /**
     * Sums the pair weights over every pair of devices in the array. Each pair is
     * looked up under the key {@code "<minor of earlier device>-<minor of later device>"}.
     *
     * @param deviceArr the devices forming one combination
     * @return the total of all pair weights
     * @throws IllegalStateException if a pair has no recorded weight
     */
    @VisibleForTesting
    public int computeCostOfDevices(Device[] deviceArr) {
        int totalCost = 0;
        for (int first = 0; first < deviceArr.length; first++) {
            String firstMinor = String.valueOf(deviceArr[first].getMinorNumber());
            for (int second = first + 1; second < deviceArr.length; second++) {
                String pairKey = firstMinor + "-" + String.valueOf(deviceArr[second].getMinorNumber());
                Integer weight = devicePairToWeight.get(pairKey);
                if (weight == null) {
                    // Fail with the offending pair named instead of an anonymous
                    // NullPointerException from unboxing the missing weight.
                    throw new IllegalStateException("No topology weight known for GPU pair " + pairKey);
                }
                totalCost += weight;
            }
        }
        return totalCost;
    }

    /**
     * Chooses {@code i} devices using the cost table. The candidate combinations of
     * size {@code i} are walked in list order, or in reverse order when the policy is
     * SPREAD (case-insensitive), and the first combination fully contained in the
     * available devices is added to the output set. The policy defaults to PACK.
     *
     * <p>Nothing is added when the cost table is missing, has no entry for {@code i},
     * or contains no combination that is entirely available.
     *
     * @param set output set receiving the chosen devices
     * @param i number of devices requested
     * @param map request environment, read for the topology policy; may be null
     * @param set2 the devices currently available
     * @param map2 cost table keyed by combination size
     */
    @VisibleForTesting
    public void topologyAwareSchedule(Set<Device> set, int i, Map<String, String> map, Set<Device> set2, Map<Integer, List<Map.Entry<Set<Device>, Integer>>> map2) {
        String policy = (map == null) ? null : map.get(TOPOLOGY_POLICY_ENV_KEY);
        if (policy == null) {
            policy = TOPOLOGY_POLICY_PACK;
        }
        if (map2 == null) {
            LOG.error("No cost table initialized!");
            return;
        }
        List<Map.Entry<Set<Device>, Integer>> candidates = map2.get(Integer.valueOf(i));
        if (candidates == null) {
            // The table only holds the sizes it was built for; a missing size must
            // lead to the caller's fallback, not to a NullPointerException.
            LOG.error("No cost table entry for allocating {} devices", i);
            return;
        }
        boolean spread = policy.equalsIgnoreCase(TOPOLOGY_POLICY_SPREAD);
        // A ListIterator walks either direction on any List implementation, so the
        // table does not have to be backed by LinkedList.
        ListIterator<Map.Entry<Set<Device>, Integer>> cursor =
                candidates.listIterator(spread ? candidates.size() : 0);
        while (spread ? cursor.hasPrevious() : cursor.hasNext()) {
            Map.Entry<Set<Device>, Integer> candidate = spread ? cursor.previous() : cursor.next();
            if (set2.containsAll(candidate.getKey())) {
                set.addAll(candidate.getKey());
                LOG.info("Topology scheduler allocated: " + set);
                return;
            }
        }
        LOG.error("Unknown error happened in topology scheduler");
    }

    /**
     * Adds up to {@code i} devices to the output set, in the iteration order of the
     * available set. When {@code i} equals the number of available devices they are
     * all added; when {@code i} is zero or negative nothing is added.
     *
     * @param set output set receiving the chosen devices
     * @param i number of devices requested
     * @param set2 the devices currently available
     */
    @VisibleForTesting
    public void basicSchedule(Set<Device> set, int i, Set<Device> set2) {
        if (i == set2.size()) {
            set.addAll(set2);
            return;
        }
        int taken = 0;
        for (Device device : set2) {
            // Check before adding: a request for zero devices must take none, whereas
            // checking only after the add would never match and take every device.
            if (taken >= i) {
                return;
            }
            set.add(device);
            taken++;
        }
    }

    /**
     * Parses a topology matrix (as printed by the GPU binary's {@code topo -m}) into
     * pair weights. Parsing stops at the line starting with "Legend"; blank lines and
     * lines containing "Affinity" are skipped. In each remaining row the first token
     * names the row device and the token in column {@code c} (1-based) describes its
     * link to device {@code c - 1}.
     *
     * <p>Rows whose name does not end in a device index (for example rows describing
     * other kinds of devices) are skipped instead of aborting the parse.
     *
     * @param str the raw topology output
     * @param map receives the weights, keyed as {@code "<row>-<column>"}
     */
    public void parseTopo(String str, Map<String, Integer> map) {
        for (String rawLine : str.split("\n")) {
            String line = rawLine.trim();
            if (line.isEmpty()) {
                continue;
            }
            if (line.startsWith("Legend")) {
                return;
            }
            if (line.contains("Affinity")) {
                continue;
            }
            String[] tokens = line.split("\\s+");
            String rowName = tokens[0];
            int rowDevice;
            try {
                rowDevice = Integer.parseInt(rowName.substring(rowName.lastIndexOf("U") + 1));
            } catch (NumberFormatException e) {
                LOG.debug("Skip topology row that does not name a GPU: {}", rowName);
                continue;
            }
            for (int column = 1; column < tokens.length; column++) {
                DeviceLinkType linkType = toDeviceLinkType(tokens[column]);
                if (linkType != null) {
                    populateGraphEdgeWeight(linkType, rowDevice, column - 1, map);
                }
            }
        }
    }

    /**
     * Maps one cell of the topology matrix to a link type.
     *
     * @param token the cell text
     * @return the link type, or {@code null} for "X" and for unrecognized cells
     */
    private DeviceLinkType toDeviceLinkType(String token) {
        switch (token) {
            case "SOC":
            case "SYS":
                return DeviceLinkType.P2PLinkCrossCPUSocket;
            case "PHB":
            case "NODE":
                return DeviceLinkType.P2PLinkSameCPUSocket;
            case "PXB":
                return DeviceLinkType.P2PLinkMultiSwitch;
            case "PIX":
                return DeviceLinkType.P2PLinkSingleSwitch;
            default:
                break;
        }
        if (!token.startsWith("NV")) {
            return null;
        }
        int linkCount;
        try {
            linkCount = Integer.parseInt(token.substring(2));
        } catch (NumberFormatException e) {
            return null;
        }
        if (linkCount < 1) {
            return null;
        }
        // DeviceLinkType stops at nine links. Larger counts are given the NVLink9
        // weight so that such pairs still get a weight instead of being dropped.
        return DeviceLinkType.valueOf("P2PLinkNVLink" + Math.min(linkCount, 9));
    }

    /**
     * Records the weight of a link type for a device pair under the key
     * {@code "<rowDevice>-<columnDevice>"}.
     *
     * @param linkType the link type whose weight is stored
     * @param rowDevice first number of the key
     * @param columnDevice second number of the key
     * @param pairToWeight the map receiving the entry
     */
    private void populateGraphEdgeWeight(DeviceLinkType linkType, int rowDevice, int columnDevice, Map<String, Integer> pairToWeight) {
        String pairKey = rowDevice + "-" + columnDevice;
        pairToWeight.put(pairKey, linkType.getWeight());
    }

    /**
     * Overrides the path of the GPU binary.
     *
     * @param str the path to use
     */
    @VisibleForTesting
    public void setPathOfGpuBinary(String str) {
        pathOfGpuBinary = str;
    }

    /**
     * Replaces the executor used to run external commands.
     *
     * @param nvidiaCommandExecutor the executor to use from now on
     */
    @VisibleForTesting
    public void setShellExecutor(NvidiaCommandExecutor nvidiaCommandExecutor) {
        shellExecutor = nvidiaCommandExecutor;
    }

    /**
     * @return whether {@code initCostTable()} has run to completion
     */
    @VisibleForTesting
    public boolean isTopoInitialized() {
        return topoInitialized;
    }

    /**
     * @return the cost table itself (not a copy), keyed by combination size
     */
    @VisibleForTesting
    public Map<Integer, List<Map.Entry<Set<Device>, Integer>>> getCostTable() {
        return costTable;
    }

    /**
     * @return the pair-weight map itself (not a copy)
     */
    @VisibleForTesting
    public Map<String, Integer> getDevicePairToWeight() {
        return devicePairToWeight;
    }
}
