/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.api.common.functions.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.api.common.TaskInfoImpl;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
import org.assertj.core.api.AbstractThrowableAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

class RuntimeUDFContextTest {
    private final TaskInfo taskInfo = new TaskInfoImpl("test name", 3, 1, 3, 0);

    RuntimeUDFContextTest() {
    }

    @Test
    void testBroadcastVariableNotFound() {
        RuntimeUDFContext ctx = new RuntimeUDFContext(this.taskInfo, this.getClass().getClassLoader(), new ExecutionConfig(), new HashMap(), new HashMap(), UnregisteredMetricsGroup.createOperatorMetricGroup());
        Assertions.assertThat((boolean)ctx.hasBroadcastVariable("some name")).isFalse();
        ((AbstractThrowableAssert)Assertions.assertThatThrownBy(() -> ctx.getBroadcastVariable("some name")).isInstanceOf(IllegalArgumentException.class)).hasMessageContaining("some name");
        ((AbstractThrowableAssert)Assertions.assertThatThrownBy(() -> ctx.getBroadcastVariableWithInitializer("some name", data -> null)).isInstanceOf(IllegalArgumentException.class)).hasMessageContaining("some name");
    }

    @Test
    void testBroadcastVariableSimple() {
        RuntimeUDFContext ctx = new RuntimeUDFContext(this.taskInfo, this.getClass().getClassLoader(), new ExecutionConfig(), new HashMap(), new HashMap(), UnregisteredMetricsGroup.createOperatorMetricGroup());
        ctx.setBroadcastVariable("name1", Arrays.asList(1, 2, 3, 4));
        ctx.setBroadcastVariable("name2", Arrays.asList(1.0, 2.0, 3.0, 4.0));
        Assertions.assertThat((boolean)ctx.hasBroadcastVariable("name1")).isTrue();
        Assertions.assertThat((boolean)ctx.hasBroadcastVariable("name2")).isTrue();
        List list1 = ctx.getBroadcastVariable("name1");
        List list2 = ctx.getBroadcastVariable("name2");
        Assertions.assertThat((List)list1).isEqualTo(Arrays.asList(1, 2, 3, 4));
        Assertions.assertThat((List)list2).isEqualTo(Arrays.asList(1.0, 2.0, 3.0, 4.0));
        List list3 = ctx.getBroadcastVariable("name1");
        List list4 = ctx.getBroadcastVariable("name2");
        Assertions.assertThat((List)list3).isEqualTo(Arrays.asList(1, 2, 3, 4));
        Assertions.assertThat((List)list4).isEqualTo(Arrays.asList(1.0, 2.0, 3.0, 4.0));
        List list5 = ctx.getBroadcastVariable("name1");
        List list6 = ctx.getBroadcastVariable("name2");
        Assertions.assertThat((List)list5).isEqualTo(Arrays.asList(1, 2, 3, 4));
        Assertions.assertThat((List)list6).isEqualTo(Arrays.asList(1.0, 2.0, 3.0, 4.0));
    }

    @Test
    void testBroadcastVariableWithInitializer() {
        RuntimeUDFContext ctx = new RuntimeUDFContext(this.taskInfo, this.getClass().getClassLoader(), new ExecutionConfig(), new HashMap(), new HashMap(), UnregisteredMetricsGroup.createOperatorMetricGroup());
        ctx.setBroadcastVariable("name", Arrays.asList(1, 2, 3, 4));
        List list = (List)ctx.getBroadcastVariableWithInitializer("name", (BroadcastVariableInitializer)new ConvertingInitializer());
        Assertions.assertThat((List)list).isEqualTo(Arrays.asList(1.0, 2.0, 3.0, 4.0));
        List list2 = (List)ctx.getBroadcastVariableWithInitializer("name", (BroadcastVariableInitializer)new ConvertingInitializer());
        Assertions.assertThat((List)list2).isEqualTo(Arrays.asList(1.0, 2.0, 3.0, 4.0));
        List list3 = ctx.getBroadcastVariable("name");
        Assertions.assertThat((List)list3).isEqualTo(Arrays.asList(1.0, 2.0, 3.0, 4.0));
    }

    @Test
    void testResetBroadcastVariableWithInitializer() {
        RuntimeUDFContext ctx = new RuntimeUDFContext(this.taskInfo, this.getClass().getClassLoader(), new ExecutionConfig(), new HashMap(), new HashMap(), UnregisteredMetricsGroup.createOperatorMetricGroup());
        ctx.setBroadcastVariable("name", Arrays.asList(1, 2, 3, 4));
        List list = (List)ctx.getBroadcastVariableWithInitializer("name", (BroadcastVariableInitializer)new ConvertingInitializer());
        Assertions.assertThat((List)list).isEqualTo(Arrays.asList(1.0, 2.0, 3.0, 4.0));
        ctx.setBroadcastVariable("name", Arrays.asList(2, 3, 4, 5));
        List list2 = (List)ctx.getBroadcastVariableWithInitializer("name", (BroadcastVariableInitializer)new ConvertingInitializer());
        Assertions.assertThat((List)list2).isEqualTo(Arrays.asList(2.0, 3.0, 4.0, 5.0));
    }

    @Test
    void testBroadcastVariableWithInitializerAndMismatch() {
        RuntimeUDFContext ctx = new RuntimeUDFContext(this.taskInfo, this.getClass().getClassLoader(), new ExecutionConfig(), new HashMap(), new HashMap(), UnregisteredMetricsGroup.createOperatorMetricGroup());
        ctx.setBroadcastVariable("name", Arrays.asList(1, 2, 3, 4));
        int sum = (Integer)ctx.getBroadcastVariableWithInitializer("name", (BroadcastVariableInitializer)new SumInitializer());
        Assertions.assertThat((int)sum).isEqualTo(10);
        Assertions.assertThatThrownBy(() -> ctx.getBroadcastVariable("name")).isInstanceOf(IllegalStateException.class);
    }

    private static final class SumInitializer
    implements BroadcastVariableInitializer<Integer, Integer> {
        private SumInitializer() {
        }

        public Integer initializeBroadcastVariable(Iterable<Integer> data) {
            int sum = 0;
            for (Integer i : data) {
                sum += i.intValue();
            }
            return sum;
        }
    }

    private static final class ConvertingInitializer
    implements BroadcastVariableInitializer<Integer, List<Double>> {
        private ConvertingInitializer() {
        }

        public List<Double> initializeBroadcastVariable(Iterable<Integer> data) {
            ArrayList<Double> list = new ArrayList<Double>();
            for (Integer i : data) {
                list.add(i.doubleValue());
            }
            return list;
        }
    }
}

