package org.apache.spark.network.shuffle;

import com.google.common.collect.Lists;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.socket.nio.NioSocketChannel;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.protocol.Message;
import org.apache.spark.network.protocol.MessageEncoder;
import org.apache.spark.network.protocol.MessageWithHeader;
import org.apache.spark.network.protocol.RpcRequest;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.shuffle.ShuffleTransportContext;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.util.ByteArrayWritableChannel;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

/**
 * Tests for {@link ShuffleTransportContext}.
 *
 * <p>Checks that the {@code "finalizeHandler"} pipeline entry is present exactly when
 * {@code spark.shuffle.server.finalizeShuffleMergeThreadsPercent} is non-zero. Also checks that the
 * context's decoder turns a {@link FinalizeShuffleMerge} RPC into a
 * {@code ShuffleTransportContext.RpcRequestInternal}, while other RPCs (here {@link OpenBlocks})
 * come out as a plain {@link RpcRequest}.
 */
public class ShuffleTransportContextSuite {
    private ExternalBlockHandler blockHandler;

    @BeforeEach
    public void before() throws IOException {
        blockHandler = Mockito.mock(ExternalBlockHandler.class);
    }

    /**
     * Builds a {@code "shuffle"} module {@link TransportConf}.
     *
     * @param finalizeThreadsEnabled when true, sets
     *     {@code spark.shuffle.server.finalizeShuffleMergeThreadsPercent} to {@code "1"},
     *     otherwise to {@code "0"}
     * @return the transport configuration backed by an in-memory config map
     */
    protected TransportConf createTransportConf(boolean finalizeThreadsEnabled) {
        HashMap<String, String> configs = new HashMap<>();
        configs.put(
            "spark.shuffle.server.finalizeShuffleMergeThreadsPercent",
            finalizeThreadsEnabled ? "1" : "0");
        return new TransportConf("shuffle", new MapConfigProvider(configs));
    }

    /**
     * Creates a {@link ShuffleTransportContext} over the mocked block handler.
     *
     * @param finalizeThreadsEnabled forwarded to {@link #createTransportConf(boolean)}
     */
    ShuffleTransportContext createShuffleTransportContext(boolean finalizeThreadsEnabled)
            throws IOException {
        return new ShuffleTransportContext(
            createTransportConf(finalizeThreadsEnabled), blockHandler, true);
    }

    /**
     * Encodes {@code message} with {@link MessageEncoder} and returns the encoded bytes as a
     * heap buffer whose reader index is past the leading 8-byte {@code long}. That long is
     * presumably the frame length written by the encoder, which the decoder under test does
     * not consume.
     *
     * @param message the message to encode; the encoder is expected to emit a single
     *     {@link MessageWithHeader}
     * @return a buffer ready to be passed to the shuffle message decoder
     */
    private ByteBuf getDecodableMessageBuf(Message message) throws Exception {
        ArrayList<Object> encoded = Lists.newArrayList();
        ChannelHandlerContext context = Mockito.mock(ChannelHandlerContext.class);
        Mockito.when(context.alloc()).thenReturn(ByteBufAllocator.DEFAULT);
        MessageEncoder.INSTANCE.encode(context, message, encoded);
        MessageWithHeader messageWithHeader = (MessageWithHeader) encoded.remove(0);
        try {
            ByteArrayWritableChannel writableChannel =
                new ByteArrayWritableChannel((int) messageWithHeader.count());
            // transferTo may write only part of the message per call, so loop until done.
            while (messageWithHeader.transfered() < messageWithHeader.count()) {
                messageWithHeader.transferTo(writableChannel, messageWithHeader.transfered());
            }
            ByteBuf messageBuf = Unpooled.wrappedBuffer(writableChannel.getData());
            messageBuf.readLong(); // skip the leading long (frame length)
            return messageBuf;
        } finally {
            // The encoded message is reference counted and its header came from
            // ByteBufAllocator.DEFAULT; no channel will release it here, so do it explicitly.
            // The returned buffer wraps a separate byte array and is unaffected.
            messageWithHeader.release();
        }
    }

    @Test
    public void testInitializePipeline() throws IOException {
        for (boolean finalizeThreadsEnabled : new boolean[]{true, false}) {
            // NOTE(review): third initializePipeline argument assumed to be the client flag.
            for (boolean isClient : new boolean[]{true, false}) {
                try (ShuffleTransportContext transportContext =
                         createShuffleTransportContext(finalizeThreadsEnabled)) {
                    NioSocketChannel channel = new NioSocketChannel();
                    try {
                        transportContext.initializePipeline(
                            channel, Mockito.mock(RpcHandler.class), isClient);
                        if (finalizeThreadsEnabled) {
                            Assertions.assertNotNull(channel.pipeline().get("finalizeHandler"));
                        } else {
                            Assertions.assertNull(channel.pipeline().get("finalizeHandler"));
                        }
                    } finally {
                        // The channel opens a socket but is never registered with an event
                        // loop, so close it directly rather than via close().
                        channel.unsafe().closeForcibly();
                    }
                }
            }
        }
    }

    @Test
    public void testDecodeOfFinalizeShuffleMessage() throws Exception {
        FinalizeShuffleMerge finalizeRequest = new FinalizeShuffleMerge("app0", 1, 2, 3);
        ByteBuf messageBuf = getDecodableMessageBuf(
            new RpcRequest(1L, new NioManagedBuffer(finalizeRequest.toByteBuffer())));
        try (ShuffleTransportContext transportContext = createShuffleTransportContext(true)) {
            ShuffleTransportContext.ShuffleMessageDecoder decoder = transportContext.getDecoder();
            ArrayList<Object> decoded = Lists.newArrayList();
            decoder.decode(Mockito.mock(ChannelHandlerContext.class), messageBuf, decoded);
            Assertions.assertEquals(1, decoded.size());
            Assertions.assertInstanceOf(
                ShuffleTransportContext.RpcRequestInternal.class, decoded.get(0));
            Assertions.assertEquals(
                BlockTransferMessage.Type.FINALIZE_SHUFFLE_MERGE,
                ((ShuffleTransportContext.RpcRequestInternal) decoded.get(0)).messageType());
        }
    }

    @Test
    public void testDecodeOfAnyOtherRpcMessage() throws Exception {
        OpenBlocks openBlocks = new OpenBlocks("app0", "1", new String[]{"block1", "block2"});
        RpcRequest rpcRequest =
            new RpcRequest(1L, new NioManagedBuffer(openBlocks.toByteBuffer()));
        ByteBuf messageBuf = getDecodableMessageBuf(rpcRequest);
        try (ShuffleTransportContext transportContext = createShuffleTransportContext(true)) {
            ShuffleTransportContext.ShuffleMessageDecoder decoder = transportContext.getDecoder();
            ArrayList<Object> decoded = Lists.newArrayList();
            decoder.decode(Mockito.mock(ChannelHandlerContext.class), messageBuf, decoded);
            Assertions.assertEquals(1, decoded.size());
            Assertions.assertInstanceOf(RpcRequest.class, decoded.get(0));
            Assertions.assertEquals(
                rpcRequest.requestId, ((RpcRequest) decoded.get(0)).requestId);
        }
    }
}
