/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.mapreduce.v2;

import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.v2.api.records.JobState;
import org.apache.hadoop.mapreduce.v2.api.records.TaskAttemptId;
import org.apache.hadoop.mapreduce.v2.api.records.TaskAttemptState;
import org.apache.hadoop.mapreduce.v2.api.records.TaskState;
import org.apache.hadoop.mapreduce.v2.app.MRApp;
import org.apache.hadoop.mapreduce.v2.app.job.Job;
import org.apache.hadoop.mapreduce.v2.app.job.Task;
import org.apache.hadoop.mapreduce.v2.app.job.TaskAttempt;
import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptEvent;
import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptEventType;
import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptStatusUpdateEvent;
import org.apache.hadoop.mapreduce.v2.app.speculate.LegacyTaskRuntimeEstimator;
import org.apache.hadoop.mapreduce.v2.app.speculate.SimpleExponentialTaskRuntimeEstimator;
import org.apache.hadoop.mapreduce.v2.app.speculate.TaskRuntimeEstimator;
import org.apache.hadoop.service.Service;
import org.apache.hadoop.test.GenericTestUtils;
import org.apache.hadoop.yarn.event.Event;
import org.apache.hadoop.yarn.event.EventHandler;
import org.apache.hadoop.yarn.util.Clock;
import org.apache.hadoop.yarn.util.ControlledClock;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

public class TestSpeculativeExecutionWithMRApp {
    private static final int NUM_MAPPERS = 5;
    private static final int NUM_REDUCERS = 0;
    private Class<? extends TaskRuntimeEstimator> estimatorClass;
    private ControlledClock controlledClk;

    public static Collection<Object[]> getTestParameters() {
        return Arrays.asList({SimpleExponentialTaskRuntimeEstimator.class}, {LegacyTaskRuntimeEstimator.class});
    }

    public void initTestSpeculativeExecutionWithMRApp(Class<? extends TaskRuntimeEstimator> pEstimatorKlass) {
        this.estimatorClass = pEstimatorKlass;
        this.controlledClk = new ControlledClock();
        this.setup();
    }

    public void setup() {
        this.controlledClk.setTime(System.currentTimeMillis());
    }

    @ParameterizedTest(name="{index}: TaskEstimator(EstimatorClass {0})")
    @MethodSource(value={"getTestParameters"})
    @Timeout(value=360L)
    public void testSpeculateSuccessfulWithoutUpdateEvents(Class<? extends TaskRuntimeEstimator> pEstimatorKlass) throws Exception {
        this.initTestSpeculativeExecutionWithMRApp(pEstimatorKlass);
        MRApp app = new MRApp(5, 0, false, "test", true, (Clock)this.controlledClk);
        Job job = app.submit(this.createConfiguration(), true, true);
        app.waitForState(job, JobState.RUNNING);
        Map tasks = job.getTasks();
        Assertions.assertEquals((int)5, (int)tasks.size(), (String)"Num tasks is not correct");
        Iterator taskIter = tasks.values().iterator();
        while (taskIter.hasNext()) {
            app.waitForState((Task)taskIter.next(), TaskState.RUNNING);
        }
        this.controlledClk.tickMsec(1000L);
        EventHandler appEventHandler = app.getContext().getEventHandler();
        for (Map.Entry mapTask : tasks.entrySet()) {
            for (Map.Entry entry : ((Task)mapTask.getValue()).getAttempts().entrySet()) {
                this.updateTaskProgress(appEventHandler, (TaskAttempt)entry.getValue(), 0.8f);
            }
        }
        Random generator = new Random();
        Object[] taskValues = tasks.values().toArray();
        Task taskToBeSpeculated = (Task)taskValues[generator.nextInt(taskValues.length)];
        for (Map.Entry entry : tasks.entrySet()) {
            if (entry.getKey() == taskToBeSpeculated.getID()) continue;
            for (Map.Entry taskAttempt : ((Task)entry.getValue()).getAttempts().entrySet()) {
                TaskAttemptId taId = (TaskAttemptId)taskAttempt.getKey();
                if (taId.getId() > 0) continue;
                TestSpeculativeExecutionWithMRApp.markTACompleted(appEventHandler, (TaskAttempt)taskAttempt.getValue());
                this.waitForTAState((TaskAttempt)taskAttempt.getValue(), TaskAttemptState.SUCCEEDED, this.controlledClk);
            }
        }
        this.controlledClk.tickMsec(2000L);
        this.waitForSpeculation(taskToBeSpeculated, this.controlledClk);
        TaskAttempt[] taskAttemptArray = TestSpeculativeExecutionWithMRApp.makeFirstAttemptWin(appEventHandler, taskToBeSpeculated);
        this.waitForTAState(taskAttemptArray[0], TaskAttemptState.SUCCEEDED, this.controlledClk);
        this.waitForAppStop(app, this.controlledClk);
    }

    @ParameterizedTest(name="{index}: TaskEstimator(EstimatorClass {0})")
    @MethodSource(value={"getTestParameters"})
    @Timeout(value=360L)
    public void testSpeculateSuccessfulWithUpdateEvents(Class<? extends TaskRuntimeEstimator> pEstimatorKlass) throws Exception {
        this.initTestSpeculativeExecutionWithMRApp(pEstimatorKlass);
        MRApp app = new MRApp(5, 0, false, "test", true, (Clock)this.controlledClk);
        Job job = app.submit(this.createConfiguration(), true, true);
        app.waitForState(job, JobState.RUNNING);
        Map tasks = job.getTasks();
        Assertions.assertEquals((int)5, (int)tasks.size(), (String)"Num tasks is not correct");
        Iterator taskIter = tasks.values().iterator();
        while (taskIter.hasNext()) {
            app.waitForState((Task)taskIter.next(), TaskState.RUNNING);
        }
        this.controlledClk.tickMsec(2000L);
        EventHandler appEventHandler = app.getContext().getEventHandler();
        for (Map.Entry mapTask : tasks.entrySet()) {
            for (Map.Entry taskAttempt : ((Task)mapTask.getValue()).getAttempts().entrySet()) {
                this.updateTaskProgress(appEventHandler, (TaskAttempt)taskAttempt.getValue(), 0.5f);
            }
        }
        Task speculatedTask = null;
        int numTasksToFinish = 4;
        this.controlledClk.tickMsec(1000L);
        for (Map.Entry task : tasks.entrySet()) {
            for (Map.Entry taskAttempt : ((Task)task.getValue()).getAttempts().entrySet()) {
                TaskAttemptId taId = (TaskAttemptId)taskAttempt.getKey();
                if (numTasksToFinish > 0 && taId.getId() == 0) {
                    TestSpeculativeExecutionWithMRApp.markTACompleted(appEventHandler, (TaskAttempt)taskAttempt.getValue());
                    --numTasksToFinish;
                    this.waitForTAState((TaskAttempt)taskAttempt.getValue(), TaskAttemptState.SUCCEEDED, this.controlledClk);
                    continue;
                }
                speculatedTask = (Task)task.getValue();
                this.updateTaskProgress(appEventHandler, (TaskAttempt)taskAttempt.getValue(), 0.75f);
            }
        }
        this.controlledClk.tickMsec(15000L);
        for (Map.Entry task : tasks.entrySet()) {
            for (Map.Entry taskAttempt : ((Task)task.getValue()).getAttempts().entrySet()) {
                if (((TaskAttempt)taskAttempt.getValue()).getState() == TaskAttemptState.SUCCEEDED || ((TaskAttempt)taskAttempt.getValue()).getState() == TaskAttemptState.KILLED) continue;
                this.updateTaskProgress(appEventHandler, (TaskAttempt)taskAttempt.getValue(), 0.75f);
            }
        }
        Task speculatedTaskConst = speculatedTask;
        this.waitForSpeculation(speculatedTaskConst, this.controlledClk);
        TaskAttempt[] ta = TestSpeculativeExecutionWithMRApp.makeFirstAttemptWin(appEventHandler, speculatedTask);
        this.waitForTAState(ta[0], TaskAttemptState.SUCCEEDED, this.controlledClk);
        this.waitForAppStop(app, this.controlledClk);
    }

    private static TaskAttempt[] makeFirstAttemptWin(EventHandler appEventHandler, Task speculatedTask) {
        Collection attempts = speculatedTask.getAttempts().values();
        TaskAttempt[] ta = new TaskAttempt[attempts.size()];
        attempts.toArray(ta);
        TestSpeculativeExecutionWithMRApp.markTACompleted(appEventHandler, ta[0]);
        return ta;
    }

    private static void markTACompleted(EventHandler appEventHandler, TaskAttempt attempt) {
        appEventHandler.handle((Event)new TaskAttemptEvent(attempt.getID(), TaskAttemptEventType.TA_DONE));
        appEventHandler.handle((Event)new TaskAttemptEvent(attempt.getID(), TaskAttemptEventType.TA_CONTAINER_COMPLETED));
    }

    private TaskAttemptStatusUpdateEvent.TaskAttemptStatus createTaskAttemptStatus(TaskAttemptId id, float progress, TaskAttemptState state) {
        TaskAttemptStatusUpdateEvent.TaskAttemptStatus status = new TaskAttemptStatusUpdateEvent.TaskAttemptStatus();
        status.id = id;
        status.progress = progress;
        status.taskState = state;
        return status;
    }

    private Configuration createConfiguration() {
        Configuration conf = new Configuration();
        conf.setClass("yarn.app.mapreduce.am.job.task.estimator.class", this.estimatorClass, TaskRuntimeEstimator.class);
        if (SimpleExponentialTaskRuntimeEstimator.class.equals(this.estimatorClass)) {
            conf.setInt("yarn.app.mapreduce.am.job.task.estimator.simple.exponential.smooth.skip-initials", 1);
            conf.setLong("yarn.app.mapreduce.am.job.task.estimator.simple.exponential.smooth.lambda-ms", 10000L);
        }
        conf.setLong("mapreduce.job.speculative.retry-after-no-speculate", 3000L);
        return conf;
    }

    private void waitForAppStop(MRApp app, ControlledClock cClock) throws TimeoutException, InterruptedException {
        GenericTestUtils.waitFor(() -> {
            if (app.getServiceState() != Service.STATE.STOPPED) {
                cClock.tickMsec(250L);
                return false;
            }
            return true;
        }, (long)250L, (long)60000L);
    }

    private void waitForSpeculation(Task speculatedTask, ControlledClock cClock) throws TimeoutException, InterruptedException {
        GenericTestUtils.waitFor(() -> {
            if (speculatedTask.getAttempts().size() != 2) {
                cClock.tickMsec(250L);
                return false;
            }
            return true;
        }, (long)250L, (long)60000L);
    }

    public void waitForTAState(TaskAttempt attempt, TaskAttemptState finalState, ControlledClock cClock) throws Exception {
        GenericTestUtils.waitFor(() -> {
            if (attempt.getReport().getTaskAttemptState() != finalState) {
                cClock.tickMsec(250L);
                return false;
            }
            return true;
        }, (long)250L, (long)10000L);
    }

    private void updateTaskProgress(EventHandler appEventHandler, TaskAttempt attempt, float newProgress) {
        TaskAttemptStatusUpdateEvent.TaskAttemptStatus status = this.createTaskAttemptStatus(attempt.getID(), newProgress, TaskAttemptState.RUNNING);
        TaskAttemptStatusUpdateEvent event = new TaskAttemptStatusUpdateEvent(attempt.getID(), new AtomicReference<TaskAttemptStatusUpdateEvent.TaskAttemptStatus>(status));
        appEventHandler.handle((Event)event);
    }
}

