package org.apache.tez.test;

import java.io.BufferedWriter;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hdfs.MiniDFSCluster;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.client.api.YarnClient;
import org.apache.tez.client.TezClient;
import org.apache.tez.client.TezClientUtils;
import org.apache.tez.common.Preconditions;
import org.apache.tez.common.TezCommonUtils;
import org.apache.tez.common.TezUtils;
import org.apache.tez.common.counters.DAGCounter;
import org.apache.tez.common.counters.TezCounters;
import org.apache.tez.dag.api.DAG;
import org.apache.tez.dag.api.DataSourceDescriptor;
import org.apache.tez.dag.api.Edge;
import org.apache.tez.dag.api.EdgeProperty;
import org.apache.tez.dag.api.ProcessorDescriptor;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.api.Vertex;
import org.apache.tez.dag.api.client.DAGClient;
import org.apache.tez.dag.api.client.DAGStatus;
import org.apache.tez.dag.api.client.StatusGetOpts;
import org.apache.tez.dag.api.oldrecords.TaskAttemptState;
import org.apache.tez.dag.app.RecoveryParser;
import org.apache.tez.dag.history.HistoryEvent;
import org.apache.tez.dag.history.HistoryEventType;
import org.apache.tez.dag.history.events.TaskAttemptFinishedEvent;
import org.apache.tez.mapreduce.input.MRInput;
import org.apache.tez.mapreduce.output.MROutput;
import org.apache.tez.mapreduce.processor.SimpleMRProcessor;
import org.apache.tez.runtime.api.LogicalInput;
import org.apache.tez.runtime.api.LogicalOutput;
import org.apache.tez.runtime.api.ProcessorContext;
import org.apache.tez.runtime.library.api.KeyValueReader;
import org.apache.tez.runtime.library.api.KeyValueWriter;
import org.apache.tez.runtime.library.api.KeyValuesReader;
import org.apache.tez.runtime.library.conf.OrderedPartitionedKVEdgeConfig;
import org.apache.tez.runtime.library.conf.UnorderedKVEdgeConfig;
import org.apache.tez.runtime.library.partitioner.HashPartitioner;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Verifies AM recovery for a three-vertex DAG: {@code TableScan} feeds {@code Aggregation}
 * through an ordered-partitioned edge, and {@code Aggregation} feeds {@code MapJoin} through a
 * broadcast edge. {@code MapJoin} also reads the same text input that {@code TableScan} reads.
 *
 * <p>In the failure scenarios, one vertex sleeps for 60 seconds on DAG attempt 1. After 10
 * seconds the test fails AM attempt 1 through YARN, presumably while that vertex is still
 * sleeping. The test then inspects the recovery logs of each attempt to see which task attempts
 * finished before and after recovery.
 */
public class TestAMRecoveryAggregationBroadcast {
    private static final Logger LOG = LoggerFactory.getLogger(TestAMRecoveryAggregationBroadcast.class);

    /** Data-source name of the text input read by the TableScan vertex. */
    private static final String INPUT1 = "Input";
    /** Data-source name of the text input read by the MapJoin vertex (same file as INPUT1). */
    private static final String INPUT2 = "Input";
    private static final String OUTPUT = "Output";
    private static final String TABLE_SCAN = "TableScan";
    private static final String AGGREGATION = "Aggregation";
    private static final String MAP_JOIN = "MapJoin";

    /** Every input line "v", suffixed with the number of times "v" occurs in the input. */
    private static final String EXPECTED_OUTPUT = "1-5\n1-5\n1-5\n1-5\n1-5\n"
        + "2-4\n2-4\n2-4\n2-4\n"
        + "3-3\n3-3\n3-3\n"
        + "4-2\n4-2\n"
        + "5-1\n";

    /** Configuration flags (read from the processor user payload) that enable a 60s sleep. */
    private static final String TABLE_SCAN_SLEEP = "tez.test.table.scan.sleep";
    private static final String AGGREGATION_SLEEP = "tez.test.aggregation.sleep";
    private static final String MAP_JOIN_SLEEP = "tez.test.map.join.sleep";

    // Declaration order matters: INPUT_FILE and OUT_PATH depend on TEST_ROOT_DIR.
    private static final String TEST_ROOT_DIR = "target/" + TestAMRecoveryAggregationBroadcast.class.getName() + "-tmpDir";
    private static final Path INPUT_FILE = new Path(TEST_ROOT_DIR, "input.csv");
    private static final Path OUT_PATH = new Path(TEST_ROOT_DIR, "out-groups");

    private static Configuration dfsConf;
    private static MiniDFSCluster dfsCluster;
    private static MiniTezCluster tezCluster;
    private static FileSystem remoteFs;

    private TezConfiguration tezConf;
    private TezClient tezSession;

    /**
     * Reads (key, values) groups from the TableScan edge and emits (key, sum of the
     * {@link IntWritable} values) to MapJoin. Optionally sleeps 60s on DAG attempt 1.
     */
    public static class AggregationProcessor extends SimpleMRProcessor {
        private final boolean sleep;

        public AggregationProcessor(ProcessorContext context) {
            super(context);
            try {
                Configuration conf = TezUtils.createConfFromUserPayload(getContext().getUserPayload());
                this.sleep = conf.getBoolean(AGGREGATION_SLEEP, false);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }

        @Override
        public void run() throws Exception {
            if (getContext().getDAGAttemptNumber() == 1 && sleep) {
                TimeUnit.SECONDS.sleep(60L);
            }
            Preconditions.checkArgument(getInputs().size() == 1);
            Preconditions.checkArgument(getOutputs().size() == 1);
            // getReader()/getWriter() return the generic Reader/Writer types; cast to the KV views.
            KeyValuesReader reader = (KeyValuesReader) getInputs().get(TABLE_SCAN).getReader();
            KeyValueWriter writer = (KeyValueWriter) getOutputs().get(MAP_JOIN).getWriter();
            while (reader.next()) {
                Text key = (Text) reader.getCurrentKey();
                int sum = 0;
                for (Object value : reader.getCurrentValues()) {
                    sum += ((IntWritable) value).get();
                }
                writer.write(key, new IntWritable(sum));
            }
        }
    }

    /**
     * Loads the broadcast (key, count) pairs from Aggregation into memory. Then, for every line
     * "v" of its own text input, it writes "v-count" (0 when the key is unknown) to the sink.
     * Optionally sleeps 60s on DAG attempt 1.
     */
    public static class MapJoinProcessor extends SimpleMRProcessor {
        private final boolean sleep;

        public MapJoinProcessor(ProcessorContext context) {
            super(context);
            try {
                Configuration conf = TezUtils.createConfFromUserPayload(getContext().getUserPayload());
                this.sleep = conf.getBoolean(MAP_JOIN_SLEEP, false);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }

        @Override
        public void run() throws Exception {
            if (getContext().getDAGAttemptNumber() == 1 && sleep) {
                TimeUnit.SECONDS.sleep(60L);
            }
            Preconditions.checkArgument(getInputs().size() == 2);
            Preconditions.checkArgument(getOutputs().size() == 1);

            KeyValueReader broadcastReader = (KeyValueReader) getInputs().get(AGGREGATION).getReader();
            Map<String, Integer> countByKey = new HashMap<>();
            while (broadcastReader.next()) {
                countByKey.put(broadcastReader.getCurrentKey().toString(),
                    ((IntWritable) broadcastReader.getCurrentValue()).get());
            }

            KeyValueReader inputReader = (KeyValueReader) getInputs().get(INPUT2).getReader();
            KeyValueWriter writer = (KeyValueWriter) getOutputs().get(OUTPUT).getWriter();
            while (inputReader.next()) {
                String line = inputReader.getCurrentValue().toString();
                writer.write(NullWritable.get(), String.format("%s-%d", line, countByKey.getOrDefault(line, 0)));
            }
        }
    }

    /**
     * Emits (line, 1) for every line of its text input to the Aggregation edge.
     * Optionally sleeps 60s on DAG attempt 1.
     */
    public static class TableScanProcessor extends SimpleMRProcessor {
        private static final IntWritable one = new IntWritable(1);
        private final boolean sleep;

        public TableScanProcessor(ProcessorContext context) {
            super(context);
            try {
                Configuration conf = TezUtils.createConfFromUserPayload(getContext().getUserPayload());
                this.sleep = conf.getBoolean(TABLE_SCAN_SLEEP, false);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }

        @Override
        public void run() throws Exception {
            if (getContext().getDAGAttemptNumber() == 1 && sleep) {
                TimeUnit.SECONDS.sleep(60L);
            }
            Preconditions.checkArgument(getInputs().size() == 1);
            Preconditions.checkArgument(getOutputs().size() == 1);
            KeyValueReader reader = (KeyValueReader) getInputs().get(INPUT1).getReader();
            KeyValueWriter writer = (KeyValueWriter) getOutputs().get(AGGREGATION).getWriter();
            while (reader.next()) {
                writer.write((Text) reader.getCurrentValue(), one);
            }
        }
    }

    /** Starts a 3-datanode MiniDFSCluster, writes the sample input, and starts a MiniTezCluster. */
    @BeforeClass
    public static void setupAll() {
        try {
            dfsConf = new Configuration();
            dfsConf.set("hdfs.minidfs.basedir", TEST_ROOT_DIR);
            dfsCluster = new MiniDFSCluster.Builder(dfsConf).numDataNodes(3).format(true).racks((String[]) null).build();
            remoteFs = dfsCluster.getFileSystem();
            createSampleFile();
            if (tezCluster == null) {
                tezCluster = new MiniTezCluster(TestAMRecoveryAggregationBroadcast.class.getName(), 1, 1, 1);
                Configuration clusterConf = new Configuration(dfsConf);
                clusterConf.set("fs.defaultFS", remoteFs.getUri().toString());
                clusterConf.setInt("yarn.nodemanager.delete.debug-delay-sec", 20000);
                clusterConf.setLong("tez.am.sleep.time.before.exit.millis", 500L);
                tezCluster.init(clusterConf);
                tezCluster.start();
            }
        } catch (IOException e) {
            throw new RuntimeException("problem starting mini dfs cluster", e);
        }
    }

    /**
     * Writes INPUT_FILE with value v (1..5) on (6 - v) lines each: "1" five times ... "5" once.
     * The writer is closed even if a write fails.
     */
    private static void createSampleFile() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(
                new OutputStreamWriter(remoteFs.create(INPUT_FILE), StandardCharsets.UTF_8))) {
            for (int value = 1; value <= 5; value++) {
                for (int repeat = 0; repeat <= 5 - value; repeat++) {
                    writer.write(String.valueOf(value));
                    writer.newLine();
                }
            }
        }
    }

    /** Stops the Tez and DFS mini clusters if they were started. */
    @AfterClass
    public static void tearDownAll() {
        if (tezCluster != null) {
            tezCluster.stop();
            tezCluster = null;
        }
        if (dfsCluster != null) {
            dfsCluster.shutdown(true);
            dfsCluster = null;
        }
    }

    /**
     * Creates a fresh, randomly named staging directory and starts a Tez session. The session
     * has recovery event flushing forced (max unflushed events = 0) and staging data kept.
     */
    @Before
    public void setup() throws Exception {
        Path stagingDir = remoteFs.makeQualified(new Path(TEST_ROOT_DIR, String.valueOf(new Random().nextInt(100000))));
        TezClientUtils.ensureStagingDirExists(dfsConf, stagingDir);
        tezConf = new TezConfiguration(tezCluster.getConfig());
        tezConf.setInt("tez.dag.recovery.max.unflushed.events", 0);
        tezConf.set("tez.am.log.level", "INFO");
        tezConf.set("tez.staging-dir", stagingDir.toString());
        tezConf.setInt("tez.am.resource.memory.mb", 500);
        tezConf.set("tez.am.launch.cmd-opts", " -Xmx256m");
        // Keep staging data so the recovery logs can be read after the DAG finishes.
        tezConf.setBoolean("tez.am.staging.scratch-data.auto-delete", false);
        tezConf.setBoolean("tez.test.recovery.drain_event", true);
        tezSession = TezClient.create("TestAMRecoveryAggregationBroadcast", tezConf);
        tezSession.start();
    }

    /** Stops the Tez session, logging (not propagating) any failure. */
    @After
    public void teardown() throws InterruptedException {
        if (tezSession != null) {
            try {
                LOG.info("Stopping Tez Session");
                tezSession.stop();
            } catch (Exception e) {
                LOG.error("Failed to stop Tez session", e);
            }
        }
        tezSession = null;
    }

    /** No AM failure: every vertex finishes in attempt 1 and no attempt-2 recovery log exists. */
    @Test(timeout = 120000)
    public void testSucceed() throws Exception {
        Assert.assertEquals(3L, runDAGAndVerify(createDAG("Succeed"), false).findCounter(DAGCounter.NUM_SUCCEEDED_TASKS).getValue());
        List<HistoryEvent> attempt1Events = readRecoveryLog(1);
        Assert.assertEquals(1L, findTaskAttemptFinishedEvent(attempt1Events, 0, 0).size());
        Assert.assertEquals(1L, findTaskAttemptFinishedEvent(attempt1Events, 1, 0).size());
        Assert.assertEquals(1L, findTaskAttemptFinishedEvent(attempt1Events, 2, 0).size());
        Assert.assertEquals(Collections.emptyList(), readRecoveryLog(2));
    }

    /**
     * AM fails while TableScan sleeps: nothing finished in attempt 1, so all vertex indexes
     * 0, 1 and 2 finish in attempt 2.
     */
    @Test(timeout = 120000)
    public void testTableScanTemporalFailure() throws Exception {
        tezConf.setBoolean(TABLE_SCAN_SLEEP, true);
        Assert.assertEquals(3L, runDAGAndVerify(createDAG("TableScanTemporalFailure"), true).findCounter(DAGCounter.NUM_SUCCEEDED_TASKS).getValue());
        List<HistoryEvent> attempt1Events = readRecoveryLog(1);
        Assert.assertEquals(0L, findTaskAttemptFinishedEvent(attempt1Events, 0, 0).size());
        Assert.assertEquals(0L, findTaskAttemptFinishedEvent(attempt1Events, 1, 0).size());
        Assert.assertEquals(0L, findTaskAttemptFinishedEvent(attempt1Events, 2, 0).size());
        List<HistoryEvent> attempt2Events = readRecoveryLog(2);
        Assert.assertEquals(1L, findTaskAttemptFinishedEvent(attempt2Events, 0, 0).size());
        Assert.assertEquals(1L, findTaskAttemptFinishedEvent(attempt2Events, 1, 0).size());
        Assert.assertEquals(1L, findTaskAttemptFinishedEvent(attempt2Events, 2, 0).size());
        Assert.assertEquals(Collections.emptyList(), readRecoveryLog(3));
    }

    /**
     * AM fails while Aggregation sleeps: vertex 0 finished in attempt 1 and is not rerun;
     * vertices 1 and 2 finish in attempt 2.
     */
    @Test(timeout = 120000)
    public void testAggregationTemporalFailure() throws Exception {
        tezConf.setBoolean(AGGREGATION_SLEEP, true);
        Assert.assertEquals(3L, runDAGAndVerify(createDAG("AggregationTemporalFailure"), true).findCounter(DAGCounter.NUM_SUCCEEDED_TASKS).getValue());
        List<HistoryEvent> attempt1Events = readRecoveryLog(1);
        Assert.assertEquals(1L, findTaskAttemptFinishedEvent(attempt1Events, 0, 0).size());
        Assert.assertEquals(0L, findTaskAttemptFinishedEvent(attempt1Events, 1, 0).size());
        Assert.assertEquals(0L, findTaskAttemptFinishedEvent(attempt1Events, 2, 0).size());
        List<HistoryEvent> attempt2Events = readRecoveryLog(2);
        Assert.assertEquals(0L, findTaskAttemptFinishedEvent(attempt2Events, 0, 0).size());
        Assert.assertEquals(1L, findTaskAttemptFinishedEvent(attempt2Events, 1, 0).size());
        Assert.assertEquals(1L, findTaskAttemptFinishedEvent(attempt2Events, 2, 0).size());
        Assert.assertEquals(Collections.emptyList(), readRecoveryLog(3));
    }

    /**
     * AM fails while MapJoin sleeps: vertices 0 and 1 finished in attempt 1 and are not rerun;
     * only vertex 2 finishes in attempt 2.
     */
    @Test(timeout = 120000)
    public void testMapJoinTemporalFailure() throws Exception {
        tezConf.setBoolean(MAP_JOIN_SLEEP, true);
        Assert.assertEquals(3L, runDAGAndVerify(createDAG("MapJoinTemporalFailure"), true).findCounter(DAGCounter.NUM_SUCCEEDED_TASKS).getValue());
        List<HistoryEvent> attempt1Events = readRecoveryLog(1);
        Assert.assertEquals(1L, findTaskAttemptFinishedEvent(attempt1Events, 0, 0).size());
        Assert.assertEquals(1L, findTaskAttemptFinishedEvent(attempt1Events, 1, 0).size());
        Assert.assertEquals(0L, findTaskAttemptFinishedEvent(attempt1Events, 2, 0).size());
        List<HistoryEvent> attempt2Events = readRecoveryLog(2);
        Assert.assertEquals(0L, findTaskAttemptFinishedEvent(attempt2Events, 0, 0).size());
        Assert.assertEquals(0L, findTaskAttemptFinishedEvent(attempt2Events, 1, 0).size());
        Assert.assertEquals(1L, findTaskAttemptFinishedEvent(attempt2Events, 2, 0).size());
        Assert.assertEquals(Collections.emptyList(), readRecoveryLog(3));
    }

    /**
     * Builds the DAG: TableScan -(ordered partitioned)-> Aggregation (1 task)
     * -(broadcast)-> MapJoin, with MapJoin also reading INPUT_FILE and writing OUT_PATH.
     * The current tezConf (including any *_SLEEP flags) is shipped to every processor.
     *
     * @param name suffix for the DAG name
     */
    private DAG createDAG(String name) throws Exception {
        UserPayload payload = TezUtils.createUserPayloadFromConf(tezConf);
        DataSourceDescriptor dataSource = MRInput.createConfigBuilder(new Configuration(tezConf), TextInputFormat.class, INPUT_FILE.toString()).build();

        Vertex tableScanVertex = Vertex.create(TABLE_SCAN,
                ProcessorDescriptor.create(TableScanProcessor.class.getName()).setUserPayload(payload))
            .addDataSource(INPUT1, dataSource);
        Vertex aggregationVertex = Vertex.create(AGGREGATION,
            ProcessorDescriptor.create(AggregationProcessor.class.getName()).setUserPayload(payload), 1);
        Vertex mapJoinVertex = Vertex.create(MAP_JOIN,
                ProcessorDescriptor.create(MapJoinProcessor.class.getName()).setUserPayload(payload))
            .addDataSource(INPUT2, dataSource)
            .addDataSink(OUTPUT, MROutput.createConfigBuilder(new Configuration(tezConf), TextOutputFormat.class, OUT_PATH.toString()).build());

        EdgeProperty orderedEdge = OrderedPartitionedKVEdgeConfig
            .newBuilder(Text.class.getName(), IntWritable.class.getName(), HashPartitioner.class.getName())
            .setFromConfiguration(tezConf).build().createDefaultEdgeProperty();
        EdgeProperty broadcastEdge = UnorderedKVEdgeConfig
            .newBuilder(Text.class.getName(), IntWritable.class.getName())
            .setFromConfiguration(tezConf).build().createDefaultBroadcastEdgeProperty();

        DAG dag = DAG.create("TestAMRecoveryAggregationBroadcast_" + name);
        dag.addVertex(tableScanVertex)
            .addVertex(aggregationVertex)
            .addVertex(mapJoinVertex)
            .addEdge(Edge.create(tableScanVertex, aggregationVertex, orderedEdge))
            .addEdge(Edge.create(aggregationVertex, mapJoinVertex, broadcastEdge));
        return dag;
    }

    /**
     * Submits the DAG, optionally fails AM attempt 1 after 10 seconds, and waits for completion.
     * It then asserts that the DAG succeeded and that the MapJoin output equals EXPECTED_OUTPUT.
     *
     * @param failAttempt1 whether to fail AM attempt 1 via YARN 10 seconds after submission
     * @return the DAG counters
     */
    TezCounters runDAGAndVerify(DAG dag, boolean failAttempt1) throws Exception {
        tezSession.waitTillReady();
        DAGClient dagClient = tezSession.submitDAG(dag);
        if (failAttempt1) {
            TimeUnit.SECONDS.sleep(10L);
            // try-with-resources so the client is closed even if the fail request throws.
            try (YarnClient yarnClient = YarnClient.createYarnClient()) {
                yarnClient.init(tezConf);
                yarnClient.start();
                yarnClient.failApplicationAttempt(ApplicationAttemptId.newInstance(tezSession.getAppMasterApplicationId(), 1));
            }
        }
        DAGStatus dagStatus = dagClient.waitForCompletionWithStatusUpdates(EnumSet.of(StatusGetOpts.GET_COUNTERS));
        LOG.info("Diagnosis: " + dagStatus.getDiagnostics());
        Assert.assertEquals(DAGStatus.State.SUCCEEDED, dagStatus.getState());
        Assert.assertEquals(EXPECTED_OUTPUT, readFully(new Path(OUT_PATH, "part-v002-o000-r-00000")));
        return dagStatus.getDAGCounters();
    }

    /**
     * Reads the whole file into a UTF-8 string, closing the stream afterwards. This loops until
     * EOF because a single read() may return fewer bytes than are available.
     */
    private static String readFully(Path path) throws IOException {
        ByteArrayOutputStream content = new ByteArrayOutputStream();
        try (FSDataInputStream in = remoteFs.open(path)) {
            byte[] buffer = new byte[4096];
            int bytesRead;
            while ((bytesRead = in.read(buffer)) != -1) {
                content.write(buffer, 0, bytesRead);
            }
        }
        return new String(content.toByteArray(), StandardCharsets.UTF_8);
    }

    /**
     * Parses the recovery file of DAG 1 written by the given AM attempt.
     *
     * @param attemptNum AM attempt whose recovery directory is read
     * @return the parsed events, or an empty list if that attempt wrote no recovery file
     */
    private List<HistoryEvent> readRecoveryLog(int attemptNum) throws IOException {
        ApplicationId appId = tezSession.getAppMasterApplicationId();
        Path systemStagingDir = TezCommonUtils.getTezSystemStagingPath(tezConf, appId.toString());
        Path recoveryDir = TezCommonUtils.getRecoveryPath(systemStagingDir, tezConf);
        FileSystem fs = systemStagingDir.getFileSystem(tezConf);
        List<HistoryEvent> historyEvents = new ArrayList<>();
        // The DAG id is derived from the application id: application_X_Y -> dag_X_Y_1.
        Path recoveryFile = new Path(TezCommonUtils.getAttemptRecoveryPath(recoveryDir, attemptNum),
            appId.toString().replace("application", "dag") + "_1.recovery");
        if (fs.exists(recoveryFile)) {
            LOG.info("Read recovery file:" + recoveryFile);
            try (FSDataInputStream in = fs.open(recoveryFile)) {
                historyEvents.addAll(RecoveryParser.parseDAGRecoveryFile(in));
            }
        }
        printHistoryEvents(historyEvents, attemptNum);
        return historyEvents;
    }

    /** Logs every event, followed by a blank separator line. */
    private void printHistoryEvents(List<HistoryEvent> historyEvents, int attemptNum) {
        LOG.info("RecoveryLogs from attempt:" + attemptNum);
        for (HistoryEvent historyEvent : historyEvents) {
            LOG.info("Parsed event from recovery stream, eventType=" + historyEvent.getEventType() + ", event=" + historyEvent);
        }
        LOG.info("");
    }

    /**
     * Selects the TASK_ATTEMPT_FINISHED events of the given vertex/task whose state is not
     * KILLED.
     *
     * @param historyEvents events parsed from a recovery log
     * @param vertexId      vertex index to match ({@code getVertexID().getId()})
     * @param taskId        task index to match ({@code getTaskID().getId()})
     */
    private List<TaskAttemptFinishedEvent> findTaskAttemptFinishedEvent(List<HistoryEvent> historyEvents, int vertexId, int taskId) {
        List<TaskAttemptFinishedEvent> matches = new ArrayList<>();
        for (HistoryEvent historyEvent : historyEvents) {
            if (historyEvent.getEventType() != HistoryEventType.TASK_ATTEMPT_FINISHED) {
                continue;
            }
            TaskAttemptFinishedEvent finishedEvent = (TaskAttemptFinishedEvent) historyEvent;
            if (finishedEvent.getState() != TaskAttemptState.KILLED
                    && finishedEvent.getVertexID().getId() == vertexId
                    && finishedEvent.getTaskID().getId() == taskId) {
                matches.add(finishedEvent);
            }
        }
        return matches;
    }
}
