package org.apache.flink.streaming.tests.artificialstate;

import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.BroadcastState;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/streaming/tests/artificialstate/ArtificalOperatorStateMapper.class */
public class ArtificalOperatorStateMapper<IN, OUT> extends RichMapFunction<IN, OUT> implements CheckpointedFunction {
    private static final long serialVersionUID = -1741298597425077761L;
    private static final String LAST_NUM_SUBTASKS_STATE_NAME = "lastNumSubtasksState";
    private static final String BROADCAST_STATE_NAME = "broadcastState";
    private static final String UNION_STATE_NAME = "unionState";
    private static final String LAST_NUM_SUBTASKS_STATE_KEY = "lastNumSubtasks";
    private static final String BROADCAST_STATE_ENTRY_VALUE_PREFIX = "broadcastStateEntry-";
    private final MapFunction<IN, OUT> mapFunction;
    private transient BroadcastState<String, Integer> lastNumSubtasksBroadcastState;
    private transient BroadcastState<Integer, String> broadcastElementsState;
    private transient ListState<Integer> unionElementsState;

    public ArtificalOperatorStateMapper(MapFunction<IN, OUT> mapFunction) {
        this.mapFunction = (MapFunction) Preconditions.checkNotNull(mapFunction);
    }

    public OUT map(IN in) throws Exception {
        return (OUT) this.mapFunction.map(in);
    }

    public void initializeState(FunctionInitializationContext functionInitializationContext) throws Exception {
        this.lastNumSubtasksBroadcastState = functionInitializationContext.getOperatorStateStore().getBroadcastState(new MapStateDescriptor(LAST_NUM_SUBTASKS_STATE_NAME, StringSerializer.INSTANCE, IntSerializer.INSTANCE));
        this.broadcastElementsState = functionInitializationContext.getOperatorStateStore().getBroadcastState(new MapStateDescriptor(BROADCAST_STATE_NAME, IntSerializer.INSTANCE, StringSerializer.INSTANCE));
        this.unionElementsState = functionInitializationContext.getOperatorStateStore().getUnionListState(new ListStateDescriptor(UNION_STATE_NAME, IntSerializer.INSTANCE));
        if (!functionInitializationContext.isRestored()) {
            Preconditions.checkState(!this.lastNumSubtasksBroadcastState.iterator().hasNext());
            Preconditions.checkState(!this.broadcastElementsState.iterator().hasNext());
            Preconditions.checkState(!((Iterable) this.unionElementsState.get()).iterator().hasNext());
            return;
        }
        Integer num = (Integer) this.lastNumSubtasksBroadcastState.get(LAST_NUM_SUBTASKS_STATE_KEY);
        Preconditions.checkState(num != null);
        HashSet hashSet = new HashSet();
        for (int i = 0; i < num.intValue(); i++) {
            hashSet.add(Integer.valueOf(i));
        }
        for (Map.Entry entry : this.broadcastElementsState.entries()) {
            int intValue = ((Integer) entry.getKey()).intValue();
            Preconditions.checkState(hashSet.remove(Integer.valueOf(intValue)), "Unexpected keys in restored broadcast state.");
            Preconditions.checkState(((String) entry.getValue()).equals(getBroadcastStateEntryValue(intValue)), "Incorrect value in restored broadcast state.");
        }
        Preconditions.checkState(hashSet.size() == 0, "Missing keys in restored broadcast state.");
        for (int i2 = 0; i2 < num.intValue(); i2++) {
            hashSet.add(Integer.valueOf(i2));
        }
        Iterator it = ((Iterable) this.unionElementsState.get()).iterator();
        while (it.hasNext()) {
            Preconditions.checkState(hashSet.remove((Integer) it.next()), "Unexpected element in restored union state.");
        }
        Preconditions.checkState(hashSet.size() == 0, "Missing elements in restored union state.");
    }

    public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
        int numberOfParallelSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
        int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
        this.lastNumSubtasksBroadcastState.clear();
        this.lastNumSubtasksBroadcastState.put(LAST_NUM_SUBTASKS_STATE_KEY, Integer.valueOf(numberOfParallelSubtasks));
        this.broadcastElementsState.clear();
        for (int i = 0; i < numberOfParallelSubtasks; i++) {
            this.broadcastElementsState.put(Integer.valueOf(i), getBroadcastStateEntryValue(i));
        }
        this.unionElementsState.clear();
        this.unionElementsState.add(Integer.valueOf(indexOfThisSubtask));
    }

    private String getBroadcastStateEntryValue(int i) {
        return BROADCAST_STATE_ENTRY_VALUE_PREFIX + i;
    }
}
