package org.apache.spark.network.util;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.spark.network.util.TransportFrameDecoder;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/spark/network/util/TransportFrameDecoderSuite.class */
public class TransportFrameDecoderSuite {
    private static final Logger logger = LoggerFactory.getLogger(TransportFrameDecoderSuite.class);
    private static Random RND = new Random();

    /* loaded from: input_file:org/apache/spark/network/util/TransportFrameDecoderSuite$MockInterceptor.class */
    private static class MockInterceptor implements TransportFrameDecoder.Interceptor {
        private int remainingReads;

        MockInterceptor(int i) {
            this.remainingReads = i;
        }

        public boolean handle(ByteBuf byteBuf) throws Exception {
            byteBuf.readerIndex(byteBuf.readerIndex() + byteBuf.readableBytes());
            Assert.assertFalse(byteBuf.isReadable());
            this.remainingReads--;
            return this.remainingReads != 0;
        }

        public void exceptionCaught(Throwable th) throws Exception {
        }

        public void channelInactive() throws Exception {
        }
    }

    @AfterClass
    public static void cleanup() {
        RND = null;
    }

    @Test
    public void testFrameDecoding() throws Exception {
        TransportFrameDecoder transportFrameDecoder = new TransportFrameDecoder();
        ChannelHandlerContext mockChannelHandlerContext = mockChannelHandlerContext();
        verifyAndCloseDecoder(transportFrameDecoder, mockChannelHandlerContext, createAndFeedFrames(100, transportFrameDecoder, mockChannelHandlerContext));
    }

    @Test
    public void testConsolidationPerf() throws Exception {
        for (long j : new long[]{ByteUnit.MiB.toBytes(1L), ByteUnit.MiB.toBytes(5L), ByteUnit.MiB.toBytes(10L), ByteUnit.MiB.toBytes(20L), ByteUnit.MiB.toBytes(30L), ByteUnit.MiB.toBytes(50L), ByteUnit.MiB.toBytes(80L), ByteUnit.MiB.toBytes(100L), ByteUnit.MiB.toBytes(300L), ByteUnit.MiB.toBytes(500L), Long.MAX_VALUE}) {
            TransportFrameDecoder transportFrameDecoder = new TransportFrameDecoder(j);
            ChannelHandlerContext channelHandlerContext = (ChannelHandlerContext) Mockito.mock(ChannelHandlerContext.class);
            ArrayList arrayList = new ArrayList();
            Mockito.when(channelHandlerContext.fireChannelRead(Mockito.any())).thenAnswer(invocationOnMock -> {
                arrayList.add((ByteBuf) invocationOnMock.getArguments()[0]);
                return null;
            });
            long bytes = ByteUnit.MiB.toBytes(300L);
            int bytes2 = (int) ByteUnit.KiB.toBytes(32L);
            for (int i = 0; i < 3; i++) {
                try {
                    long j2 = 0;
                    ByteBuf buffer = Unpooled.buffer(8);
                    buffer.writeLong(8 + bytes);
                    transportFrameDecoder.channelRead(channelHandlerContext, buffer);
                    for (long j3 = 0; j3 < bytes; j3 += bytes2) {
                        ByteBuf buffer2 = Unpooled.buffer(bytes2 * 2);
                        ByteBuf writerIndex = Unpooled.buffer(bytes2).writerIndex(bytes2);
                        buffer2.writeBytes(writerIndex);
                        writerIndex.release();
                        long currentTimeMillis = System.currentTimeMillis();
                        transportFrameDecoder.channelRead(channelHandlerContext, buffer2);
                        j2 += System.currentTimeMillis() - currentTimeMillis;
                    }
                    Logger logger2 = logger;
                    logger2.info("Writing 300MiB frame buf with consolidation of threshold " + j + " took " + logger2 + " millis");
                    Iterator it = arrayList.iterator();
                    while (it.hasNext()) {
                        release((ByteBuf) it.next());
                    }
                } catch (Throwable th) {
                    Iterator it2 = arrayList.iterator();
                    while (it2.hasNext()) {
                        release((ByteBuf) it2.next());
                    }
                    throw th;
                }
            }
            long j4 = 0;
            while (arrayList.iterator().hasNext()) {
                j4 += ((ByteBuf) r0.next()).capacity();
            }
            Assert.assertEquals(3, arrayList.size());
            Assert.assertEquals(bytes * 3, j4);
        }
    }

    @Test
    public void testInterception() throws Exception {
        TransportFrameDecoder transportFrameDecoder = new TransportFrameDecoder();
        TransportFrameDecoder.Interceptor interceptor = (TransportFrameDecoder.Interceptor) Mockito.spy(new MockInterceptor(3));
        ChannelHandlerContext mockChannelHandlerContext = mockChannelHandlerContext();
        byte[] bArr = new byte[8];
        ByteBuf copyLong = Unpooled.copyLong(8 + bArr.length);
        ByteBuf wrappedBuffer = Unpooled.wrappedBuffer(bArr);
        try {
            transportFrameDecoder.setInterceptor(interceptor);
            for (int i = 0; i < 3; i++) {
                transportFrameDecoder.channelRead(mockChannelHandlerContext, wrappedBuffer);
                Assert.assertEquals(0L, wrappedBuffer.refCnt());
                wrappedBuffer = Unpooled.wrappedBuffer(bArr);
            }
            transportFrameDecoder.channelRead(mockChannelHandlerContext, copyLong);
            transportFrameDecoder.channelRead(mockChannelHandlerContext, wrappedBuffer);
            ((TransportFrameDecoder.Interceptor) Mockito.verify(interceptor, Mockito.times(3))).handle((ByteBuf) Mockito.any(ByteBuf.class));
            ((ChannelHandlerContext) Mockito.verify(mockChannelHandlerContext)).fireChannelRead(Mockito.any(ByteBuf.class));
            Assert.assertEquals(0L, copyLong.refCnt());
            Assert.assertEquals(0L, wrappedBuffer.refCnt());
            release(copyLong);
            release(wrappedBuffer);
        } catch (Throwable th) {
            release(copyLong);
            release(wrappedBuffer);
            throw th;
        }
    }

    @Test
    public void testRetainedFrames() throws Exception {
        TransportFrameDecoder transportFrameDecoder = new TransportFrameDecoder();
        AtomicInteger atomicInteger = new AtomicInteger();
        ArrayList<ByteBuf> arrayList = new ArrayList();
        ChannelHandlerContext channelHandlerContext = (ChannelHandlerContext) Mockito.mock(ChannelHandlerContext.class);
        Mockito.when(channelHandlerContext.fireChannelRead(Mockito.any())).thenAnswer(invocationOnMock -> {
            ByteBuf byteBuf = (ByteBuf) invocationOnMock.getArguments()[0];
            if (atomicInteger.incrementAndGet() % 2 == 0) {
                arrayList.add(byteBuf);
                return null;
            }
            byteBuf.release();
            return null;
        });
        ByteBuf createAndFeedFrames = createAndFeedFrames(100, transportFrameDecoder, channelHandlerContext);
        try {
            for (ByteBuf byteBuf : arrayList) {
                byteBuf.readBytes(new byte[byteBuf.readableBytes()]);
                byteBuf.release();
            }
            verifyAndCloseDecoder(transportFrameDecoder, channelHandlerContext, createAndFeedFrames);
            Iterator it = arrayList.iterator();
            while (it.hasNext()) {
                release((ByteBuf) it.next());
            }
        } catch (Throwable th) {
            Iterator it2 = arrayList.iterator();
            while (it2.hasNext()) {
                release((ByteBuf) it2.next());
            }
            throw th;
        }
    }

    @Test
    public void testSplitLengthField() throws Exception {
        byte[] bArr = new byte[1024 * (RND.nextInt(31) + 1)];
        ByteBuf buffer = Unpooled.buffer(bArr.length + 8);
        buffer.writeLong(bArr.length + 8);
        buffer.writeBytes(bArr);
        TransportFrameDecoder transportFrameDecoder = new TransportFrameDecoder();
        ChannelHandlerContext mockChannelHandlerContext = mockChannelHandlerContext();
        try {
            transportFrameDecoder.channelRead(mockChannelHandlerContext, buffer.readSlice(RND.nextInt(7)).retain());
            ((ChannelHandlerContext) Mockito.verify(mockChannelHandlerContext, Mockito.never())).fireChannelRead(Mockito.any(ByteBuf.class));
            transportFrameDecoder.channelRead(mockChannelHandlerContext, buffer);
            ((ChannelHandlerContext) Mockito.verify(mockChannelHandlerContext)).fireChannelRead(Mockito.any(ByteBuf.class));
            Assert.assertEquals(0L, buffer.refCnt());
            transportFrameDecoder.channelInactive(mockChannelHandlerContext);
            release(buffer);
        } catch (Throwable th) {
            transportFrameDecoder.channelInactive(mockChannelHandlerContext);
            release(buffer);
            throw th;
        }
    }

    @Test
    public void testNegativeFrameSize() {
        Assert.assertThrows(IllegalArgumentException.class, () -> {
            testInvalidFrame(-1L);
        });
    }

    @Test
    public void testEmptyFrame() {
        Assert.assertThrows(IllegalArgumentException.class, () -> {
            testInvalidFrame(8L);
        });
    }

    private ByteBuf createAndFeedFrames(int i, TransportFrameDecoder transportFrameDecoder, ChannelHandlerContext channelHandlerContext) throws Exception {
        ByteBuf buffer = Unpooled.buffer();
        for (int i2 = 0; i2 < i; i2++) {
            byte[] bArr = new byte[1024 * (RND.nextInt(31) + 1)];
            buffer.writeLong(bArr.length + 8);
            buffer.writeBytes(bArr);
        }
        while (buffer.isReadable()) {
            try {
                transportFrameDecoder.channelRead(channelHandlerContext, buffer.readSlice(Math.min(buffer.readableBytes(), RND.nextInt(4096) + 256)).retain());
            } catch (Exception e) {
                release(buffer);
                throw e;
            }
        }
        ((ChannelHandlerContext) Mockito.verify(channelHandlerContext, Mockito.times(i))).fireChannelRead(Mockito.any(ByteBuf.class));
        return buffer;
    }

    private void verifyAndCloseDecoder(TransportFrameDecoder transportFrameDecoder, ChannelHandlerContext channelHandlerContext, ByteBuf byteBuf) throws Exception {
        try {
            transportFrameDecoder.channelInactive(channelHandlerContext);
            Assert.assertTrue("There shouldn't be dangling references to the data.", byteBuf.release());
            release(byteBuf);
        } catch (Throwable th) {
            release(byteBuf);
            throw th;
        }
    }

    private void testInvalidFrame(long j) throws Exception {
        TransportFrameDecoder transportFrameDecoder = new TransportFrameDecoder();
        ChannelHandlerContext channelHandlerContext = (ChannelHandlerContext) Mockito.mock(ChannelHandlerContext.class);
        ByteBuf copyLong = Unpooled.copyLong(j);
        try {
            transportFrameDecoder.channelRead(channelHandlerContext, copyLong);
            release(copyLong);
        } catch (Throwable th) {
            release(copyLong);
            throw th;
        }
    }

    private ChannelHandlerContext mockChannelHandlerContext() {
        ChannelHandlerContext channelHandlerContext = (ChannelHandlerContext) Mockito.mock(ChannelHandlerContext.class);
        Mockito.when(channelHandlerContext.fireChannelRead(Mockito.any())).thenAnswer(invocationOnMock -> {
            ((ByteBuf) invocationOnMock.getArguments()[0]).release();
            return null;
        });
        return channelHandlerContext;
    }

    private void release(ByteBuf byteBuf) {
        if (byteBuf.refCnt() > 0) {
            byteBuf.release(byteBuf.refCnt());
        }
    }
}
