/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kudu.client;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import javax.annotation.concurrent.NotThreadSafe;
import org.apache.kudu.ColumnSchema;
import org.apache.kudu.Schema;
import org.apache.kudu.client.AbstractKuduScannerBuilder;
import org.apache.kudu.client.AsyncKuduClient;
import org.apache.kudu.client.Bytes;
import org.apache.kudu.client.KeyEncoder;
import org.apache.kudu.client.KuduPredicate;
import org.apache.kudu.client.PartialRow;
import org.apache.kudu.client.Partition;
import org.apache.kudu.client.PartitionSchema;
import org.apache.kudu.shaded.com.google.common.base.Preconditions;
import org.apache.kudu.shaded.com.google.common.collect.ImmutableList;
import org.apache.kudu.util.ByteVec;
import org.apache.kudu.util.Pair;
import org.apache.yetus.audience.InterfaceAudience;

/**
 * Computes the partition key ranges a scan has to visit, based on the scan's
 * column predicates and its primary key / partition key bounds, and tracks
 * which of those ranges are still left to scan.
 *
 * <p>Each range is a {@code Pair} of encoded partition keys {@code [lower, upper)};
 * an empty {@code upper} means the range is unbounded above. A partition key is
 * built from the encoded hash buckets followed by the encoded range key. The
 * ranges are sorted by lower bound when the pruner is created and are consumed
 * from the front of the queue.
 */
@InterfaceAudience.Private
@NotThreadSafe
public class PartitionPruner {
    /** Partition key ranges not yet removed, ordered by lower bound. */
    private final Deque<Pair<byte[], byte[]>> rangePartitions;

    private PartitionPruner(Deque<Pair<byte[], byte[]>> rangePartitions) {
        this.rangePartitions = rangePartitions;
    }

    /** Returns how many partition key ranges have not been removed yet. */
    public int numRangesRemainingForTests() {
        return this.rangePartitions.size();
    }

    /** Returns a pruner without any ranges, used for scans that cannot match any row. */
    private static PartitionPruner empty() {
        return new PartitionPruner(new ArrayDeque<Pair<byte[], byte[]>>());
    }

    /**
     * Creates a pruner for the given scan.
     *
     * <p>An empty pruner is returned when the primary key interval is empty, when
     * any predicate has type {@code NONE}, or when the range key interval derived
     * from the predicates (and, for simple range partitioning, the primary key
     * bounds) is empty. Otherwise the range key interval is split per hash
     * schema, expanded over the hash buckets that survive predicate pruning,
     * clipped to the scan's partition key bounds, and sorted by lower bound.
     *
     * @param scanner the scan builder whose table, predicates and bounds are read
     * @return a pruner holding the partition key ranges the scan has to visit
     */
    public static PartitionPruner create(AbstractKuduScannerBuilder<?, ?> scanner) {
        Schema schema = scanner.table.getSchema();
        PartitionSchema partitionSchema = scanner.table.getPartitionSchema();
        PartitionSchema.RangeSchema rangeSchema = partitionSchema.getRangeSchema();
        Map<String, KuduPredicate> predicates = scanner.predicates;

        // Short-circuit scans that provably match nothing. Past this point the
        // primary key interval is non-empty and there are no NONE predicates.
        if (scanner.upperBoundPrimaryKey.length > 0
                && Bytes.memcmp(scanner.lowerBoundPrimaryKey, scanner.upperBoundPrimaryKey) >= 0) {
            return empty();
        }
        for (KuduPredicate predicate : predicates.values()) {
            if (predicate.getType() == KuduPredicate.PredicateType.NONE) {
                return empty();
            }
        }

        // Range portion of the partition key, derived from the predicates on the
        // range partition columns.
        byte[] rangeLowerBound = pushPredsIntoLowerBoundRangeKey(schema, rangeSchema, predicates);
        byte[] rangeUpperBound = pushPredsIntoUpperBoundRangeKey(schema, rangeSchema, predicates);

        // With simple range partitioning the primary key bounds are range key
        // bounds as well; keep whichever bound is tighter.
        if (partitionSchema.isSimpleRangePartitioning()) {
            if (Bytes.memcmp(rangeLowerBound, scanner.lowerBoundPrimaryKey) < 0) {
                rangeLowerBound = scanner.lowerBoundPrimaryKey;
            }
            if (scanner.upperBoundPrimaryKey.length > 0
                    && (rangeUpperBound.length == 0
                        || Bytes.memcmp(rangeUpperBound, scanner.upperBoundPrimaryKey) > 0)) {
                rangeUpperBound = scanner.upperBoundPrimaryKey;
            }
        }

        // An empty range key interval cannot match any row. Bail out here rather
        // than handing inverted bounds to splitIntoHashSpecificRanges, which can
        // then produce no sub-range and fail its non-empty precondition.
        if (rangeUpperBound.length > 0 && Bytes.memcmp(rangeLowerBound, rangeUpperBound) >= 0) {
            return empty();
        }

        // A table may carry range-specific hash schemas, so the range key
        // interval is split into sub-ranges that each come with their own schema.
        List<PartitionSchema.EncodedRangeBoundsWithHashSchema> preliminaryRanges =
                splitIntoHashSpecificRanges(rangeLowerBound, rangeUpperBound, partitionSchema);

        List<Pair<byte[], byte[]>> partitionKeyRangeBytes = new ArrayList<>();
        for (PartitionSchema.EncodedRangeBoundsWithHashSchema preliminaryRange : preliminaryRanges) {
            for (Pair<ByteVec, ByteVec> range : buildPartitionKeyRanges(schema, predicates, preliminaryRange)) {
                byte[] lower = range.getFirst().toArray();
                byte[] upper = range.getSecond().toArray();

                // Sanity check: every generated range is non-empty.
                assert upper.length == 0 || Bytes.memcmp(lower, upper) < 0;

                // Intersect with the scan's partition key bounds.
                if (scanner.lowerBoundPartitionKey.length > 0
                        && (lower.length == 0 || Bytes.memcmp(lower, scanner.lowerBoundPartitionKey) < 0)) {
                    lower = scanner.lowerBoundPartitionKey;
                }
                if (scanner.upperBoundPartitionKey.length > 0
                        && (upper.length == 0 || Bytes.memcmp(upper, scanner.upperBoundPartitionKey) > 0)) {
                    upper = scanner.upperBoundPartitionKey;
                }

                // Drop ranges that the intersection made empty.
                if (upper.length != 0 && Bytes.memcmp(lower, upper) >= 0) {
                    continue;
                }
                partitionKeyRangeBytes.add(new Pair<>(lower, upper));
            }
        }

        Collections.sort(partitionKeyRangeBytes, (lhs, rhs) -> Bytes.memcmp(lhs.getFirst(), rhs.getFirst()));
        return new PartitionPruner(new ArrayDeque<>(partitionKeyRangeBytes));
    }

    /**
     * Expands one hash-schema-specific range into partition key ranges: one per
     * combination of surviving hash buckets across the constrained hash
     * components, each followed by the encoded range bounds of the range.
     */
    private static List<Pair<ByteVec, ByteVec>> buildPartitionKeyRanges(
            Schema schema,
            Map<String, KuduPredicate> predicates,
            PartitionSchema.EncodedRangeBoundsWithHashSchema preliminaryRange) {
        List<PartitionSchema.HashBucketSchema> hashBucketSchemas = preliminaryRange.hashSchemas;

        // Buckets of each hash component that may hold matching rows.
        List<BitSet> hashComponents = new ArrayList<>(hashBucketSchemas.size());
        for (PartitionSchema.HashBucketSchema hashSchema : hashBucketSchemas) {
            hashComponents.add(pruneHashComponent(schema, hashSchema, predicates));
        }

        int constrainedIndex = findConstrainedIndex(preliminaryRange, hashBucketSchemas, hashComponents);

        // Start from a single empty range and extend it, component by component,
        // with the cartesian product of surviving hash buckets.
        List<Pair<ByteVec, ByteVec>> partitionKeyRanges = new ArrayList<>();
        partitionKeyRanges.add(new Pair<>(ByteVec.create(), ByteVec.create()));
        for (int hashIdx = 0; hashIdx < constrainedIndex; ++hashIdx) {
            // When the last constrained component is not followed by a range
            // upper bound, its bucket is made exclusive by using bucket + 1.
            boolean isLast = hashIdx + 1 == constrainedIndex && preliminaryRange.upper.length == 0;
            BitSet hashBuckets = hashComponents.get(hashIdx);
            List<Pair<ByteVec, ByteVec>> newPartitionKeyRanges =
                    new ArrayList<>(partitionKeyRanges.size() * hashBuckets.cardinality());
            for (Pair<ByteVec, ByteVec> partitionKeyRange : partitionKeyRanges) {
                for (int bucket = hashBuckets.nextSetBit(0); bucket != -1; bucket = hashBuckets.nextSetBit(bucket + 1)) {
                    int bucketUpper = isLast ? bucket + 1 : bucket;
                    // Both ends start from the lower prefix: before the final
                    // component the lower and upper prefixes are always identical.
                    ByteVec lower = partitionKeyRange.getFirst().clone();
                    ByteVec upper = partitionKeyRange.getFirst().clone();
                    KeyEncoder.encodeHashBucket(bucket, lower);
                    KeyEncoder.encodeHashBucket(bucketUpper, upper);
                    newPartitionKeyRanges.add(new Pair<>(lower, upper));
                }
            }
            partitionKeyRanges = newPartitionKeyRanges;
        }

        // Append the range portion of the partition key.
        for (Pair<ByteVec, ByteVec> range : partitionKeyRanges) {
            range.getFirst().append(preliminaryRange.lower);
            range.getSecond().append(preliminaryRange.upper);
        }
        return partitionKeyRanges;
    }

    /**
     * Returns how many leading hash components have to be encoded into the
     * partition keys. If the range is bounded on either side, every component
     * is needed. Otherwise only the components up to and including the
     * rightmost one that excludes at least one bucket are needed.
     */
    private static int findConstrainedIndex(PartitionSchema.EncodedRangeBoundsWithHashSchema preliminaryRange,
                                            List<PartitionSchema.HashBucketSchema> hashBucketSchemas,
                                            List<BitSet> hashComponents) {
        if (preliminaryRange.lower.length > 0 || preliminaryRange.upper.length > 0) {
            return hashBucketSchemas.size();
        }
        for (int i = hashComponents.size(); i > 0; --i) {
            int numBuckets = hashBucketSchemas.get(i - 1).getNumBuckets();
            if (hashComponents.get(i - 1).nextClearBit(0) < numBuckets) {
                return i;
            }
        }
        return 0;
    }

    /** Returns true while at least one partition key range remains. */
    public boolean hasMorePartitionKeyRanges() {
        return !this.rangePartitions.isEmpty();
    }

    /**
     * Returns the lower bound of the first remaining partition key range.
     *
     * @throws java.util.NoSuchElementException if no range remains
     */
    public byte[] nextPartitionKey() {
        return this.rangePartitions.getFirst().getFirst();
    }

    /**
     * Returns the first remaining partition key range.
     *
     * @throws java.util.NoSuchElementException if no range remains
     */
    public Pair<byte[], byte[]> nextPartitionKeyRange() {
        return this.rangePartitions.getFirst();
    }

    /**
     * Removes all key space below {@code upperBound} from the remaining ranges.
     *
     * <p>Leading ranges that end at or before {@code upperBound} are dropped. A
     * range that straddles it is trimmed so that it starts at {@code upperBound}.
     * An empty {@code upperBound} (unbounded) removes every range.
     *
     * @param upperBound exclusive upper bound of the key space already covered
     */
    public void removePartitionKeyRange(byte[] upperBound) {
        if (upperBound.length == 0) {
            this.rangePartitions.clear();
            return;
        }
        while (!this.rangePartitions.isEmpty()) {
            Pair<byte[], byte[]> range = this.rangePartitions.getFirst();
            if (Bytes.memcmp(upperBound, range.getFirst()) <= 0) {
                // The next range starts at or after the covered key space.
                break;
            }
            this.rangePartitions.removeFirst();
            if (range.getSecond().length == 0 || Bytes.memcmp(upperBound, range.getSecond()) < 0) {
                // upperBound falls inside this range: keep the uncovered tail.
                this.rangePartitions.addFirst(new Pair<>(upperBound, range.getSecond()));
                break;
            }
        }
    }

    /**
     * Returns true if the scan would not need to visit {@code partition}.
     *
     * <p>Linear scan: ranges that end at or before the partition's start key are
     * skipped. The first range that does not is checked for overlap with the
     * partition. If every range is skipped, the partition is pruned.
     */
    boolean shouldPruneForTests(Partition partition) {
        for (Pair<byte[], byte[]> range : this.rangePartitions) {
            if (range.getSecond().length > 0
                    && Bytes.memcmp(range.getSecond(), partition.getPartitionKeyStart()) <= 0) {
                continue;
            }
            return partition.getPartitionKeyEnd().length > 0
                    && Bytes.memcmp(partition.getPartitionKeyEnd(), range.getFirst()) <= 0;
        }
        return true;
    }

    /** Test-only access to {@link #idsToIndexes(Schema, List)}. */
    static List<Integer> idsToIndexesForTest(Schema schema, List<Integer> ids) {
        return idsToIndexes(schema, ids);
    }

    /** Maps column ids to column indexes in {@code schema}, preserving order. */
    private static List<Integer> idsToIndexes(Schema schema, List<Integer> ids) {
        List<Integer> indexes = new ArrayList<>(ids.size());
        for (int id : ids) {
            indexes.add(schema.getColumnIndex(id));
        }
        return indexes;
    }

    /**
     * Tries to increment the key formed by the columns at {@code keyIndexes},
     * trying the last column first and moving left. Stops at the first column
     * that {@code PartialRow.incrementColumn} reports as incremented.
     *
     * @return true if some column was incremented, false if none could be
     */
    private static boolean incrementKey(PartialRow row, List<Integer> keyIndexes) {
        for (int i = keyIndexes.size() - 1; i >= 0; --i) {
            if (row.incrementColumn(keyIndexes.get(i))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds an inclusive lower bound range partition key from the predicates.
     *
     * <p>Predicates are pushed into the range partition columns in key order,
     * stopping at the first column without a usable lower value. The remaining
     * columns are set to their minimum.
     *
     * @return the encoded key, or an empty array if no predicate could be pushed
     * @throws IllegalArgumentException for a predicate type that cannot be pushed
     */
    private static byte[] pushPredsIntoLowerBoundRangeKey(Schema schema, PartitionSchema.RangeSchema rangeSchema,
                                                          Map<String, KuduPredicate> predicates) {
        PartialRow row = schema.newPartialRow();
        int pushedPredicates = 0;
        List<Integer> rangePartitionColumnIdxs = idsToIndexes(schema, rangeSchema.getColumnIds());

        columns:
        for (int idx : rangePartitionColumnIdxs) {
            ColumnSchema column = schema.getColumnByIndex(idx);
            KuduPredicate predicate = predicates.get(column.getName());
            if (predicate == null) {
                break;
            }
            switch (predicate.getType()) {
                case RANGE:
                    if (predicate.getLower() == null) {
                        break columns;
                    }
                    // fall through: a range with a lower bound pushes it like an equality
                case EQUALITY:
                    row.setRaw(idx, predicate.getLower());
                    ++pushedPredicates;
                    break;
                case IS_NOT_NULL:
                    break columns;
                case IN_LIST:
                    // The first IN-list value is used as the lower bound.
                    row.setRaw(idx, predicate.getInListValues()[0]);
                    ++pushedPredicates;
                    break;
                default:
                    throw new IllegalArgumentException(String.format("unexpected predicate type can not be pushed into key: %s", predicate));
            }
        }

        if (pushedPredicates == 0) {
            return AsyncKuduClient.EMPTY_ARRAY;
        }
        return encodeWithMinSuffix(row, rangeSchema, rangePartitionColumnIdxs, pushedPredicates);
    }

    /**
     * Builds an exclusive upper bound range partition key from the predicates.
     *
     * <p>Predicates are pushed into the range partition columns in key order. A
     * range predicate's exclusive upper value ends the key. If the last pushed
     * predicate is an equality or IN-list, its inclusive value is turned into an
     * exclusive bound by incrementing the key. The remaining columns are set to
     * their minimum.
     *
     * @return the encoded key, or an empty array (unbounded) if no predicate could
     *     be pushed or the key could not be incremented
     * @throws IllegalArgumentException for a predicate type that cannot be pushed
     */
    private static byte[] pushPredsIntoUpperBoundRangeKey(Schema schema, PartitionSchema.RangeSchema rangeSchema,
                                                          Map<String, KuduPredicate> predicates) {
        PartialRow row = schema.newPartialRow();
        int pushedPredicates = 0;
        KuduPredicate finalPredicate = null;
        List<Integer> rangePartitionColumnIdxs = idsToIndexes(schema, rangeSchema.getColumnIds());

        columns:
        for (int idx : rangePartitionColumnIdxs) {
            ColumnSchema column = schema.getColumnByIndex(idx);
            KuduPredicate predicate = predicates.get(column.getName());
            if (predicate == null) {
                break;
            }
            switch (predicate.getType()) {
                case EQUALITY:
                    row.setRaw(idx, predicate.getLower());
                    ++pushedPredicates;
                    finalPredicate = predicate;
                    break;
                case RANGE:
                    if (predicate.getUpper() != null) {
                        row.setRaw(idx, predicate.getUpper());
                        ++pushedPredicates;
                        finalPredicate = predicate;
                    }
                    // An exclusive upper value ends the key; later columns cannot tighten it.
                    break columns;
                case IS_NOT_NULL:
                    break columns;
                case IN_LIST: {
                    // The last IN-list value is used as the (inclusive) upper value.
                    byte[][] values = predicate.getInListValues();
                    row.setRaw(idx, values[values.length - 1]);
                    ++pushedPredicates;
                    finalPredicate = predicate;
                    break;
                }
                default:
                    throw new IllegalArgumentException(String.format("unexpected predicate type can not be pushed into key: %s", predicate));
            }
        }

        if (pushedPredicates == 0) {
            return AsyncKuduClient.EMPTY_ARRAY;
        }

        KuduPredicate.PredicateType finalType = finalPredicate.getType();
        boolean inclusiveFinalValue = finalType == KuduPredicate.PredicateType.EQUALITY
                || finalType == KuduPredicate.PredicateType.IN_LIST;
        if (inclusiveFinalValue && !incrementKey(row, rangePartitionColumnIdxs.subList(0, pushedPredicates))) {
            // The key could not be incremented, so it does not bound the key space.
            return AsyncKuduClient.EMPTY_ARRAY;
        }
        return encodeWithMinSuffix(row, rangeSchema, rangePartitionColumnIdxs, pushedPredicates);
    }

    /**
     * Sets every range partition column after the first {@code pushedPredicates}
     * to its minimum value, then encodes {@code row} as a range partition key.
     */
    private static byte[] encodeWithMinSuffix(PartialRow row, PartitionSchema.RangeSchema rangeSchema,
                                              List<Integer> rangePartitionColumnIdxs, int pushedPredicates) {
        Iterator<Integer> remainingIdxs = rangePartitionColumnIdxs.listIterator(pushedPredicates);
        while (remainingIdxs.hasNext()) {
            row.setMin(remainingIdxs.next());
        }
        return KeyEncoder.encodeRangePartitionKey(row, rangeSchema);
    }

    /**
     * Splits the range key interval {@code [scanLowerBound, scanUpperBound)} into
     * consecutive sub-ranges and attributes each one to a hash schema.
     *
     * <p>Sub-ranges covering a range from {@code ps.getEncodedRangesWithHashSchemas()}
     * get that range's hash schemas; gaps between those ranges get the
     * table-wide hash schema. Empty bounds mean unbounded.
     *
     * <p>NOTE(review): assumes the ranges from the partition schema are sorted by
     * lower bound and do not overlap; confirm against PartitionSchema.
     *
     * @throws IllegalStateException if no sub-range can be produced, e.g. for an
     *     inverted interval that overlaps a range with a custom hash schema
     */
    static List<PartitionSchema.EncodedRangeBoundsWithHashSchema> splitIntoHashSpecificRanges(
            byte[] scanLowerBound, byte[] scanUpperBound, PartitionSchema ps) {
        List<PartitionSchema.EncodedRangeBoundsWithHashSchema> ranges = ps.getEncodedRangesWithHashSchemas();
        List<PartitionSchema.HashBucketSchema> tableWideHashSchema = ps.getHashBucketSchemas();

        // Trivial case: no range has a custom hash schema.
        if (ranges.isEmpty()) {
            return ImmutableList.of(new PartitionSchema.EncodedRangeBoundsWithHashSchema(
                    scanLowerBound, scanUpperBound, tableWideHashSchema));
        }

        // Trivial case: the scan lies entirely before the first or after the last
        // range with a custom hash schema.
        byte[] rangesLowerBound = ranges.get(0).lower;
        byte[] rangesUpperBound = ranges.get(ranges.size() - 1).upper;
        boolean scanEndsBeforeRanges = scanUpperBound.length != 0
                && Bytes.memcmp(scanUpperBound, rangesLowerBound) <= 0;
        boolean scanStartsAfterRanges = scanLowerBound.length != 0 && rangesUpperBound.length != 0
                && Bytes.memcmp(rangesUpperBound, scanLowerBound) <= 0;
        if (scanEndsBeforeRanges || scanStartsAfterRanges) {
            return ImmutableList.of(new PartitionSchema.EncodedRangeBoundsWithHashSchema(
                    scanLowerBound, scanUpperBound, tableWideHashSchema));
        }

        // Index of the first custom-hash range that ends after the scan's lower bound.
        int curIdx = -1;
        for (int idx = 0; idx < ranges.size(); ++idx) {
            byte[] upper = ranges.get(idx).upper;
            if (upper.length == 0 || Bytes.memcmp(upper, scanLowerBound) > 0) {
                curIdx = idx;
                break;
            }
        }
        Preconditions.checkState(curIdx >= 0);
        Preconditions.checkState(curIdx < ranges.size());

        // Walk the scan interval from one known boundary to the next, emitting
        // one sub-range per step.
        byte[] curPoint = scanLowerBound;
        List<PartitionSchema.EncodedRangeBoundsWithHashSchema> result = new ArrayList<>();
        while (curIdx < ranges.size()
                && (scanUpperBound.length == 0 || Bytes.memcmp(curPoint, scanUpperBound) < 0)) {
            PartitionSchema.EncodedRangeBoundsWithHashSchema curRange = ranges.get(curIdx);
            int cmp = Bytes.memcmp(curPoint, curRange.lower);
            if (cmp < 0) {
                // curPoint is in the gap before curRange. Emit the gap with the
                // table-wide schema, up to curRange.lower or the scan upper bound,
                // whichever comes first. curIdx is not advanced.
                byte[] upperBound = scanUpperBound.length == 0 || Bytes.memcmp(curRange.lower, scanUpperBound) < 0
                        ? curRange.lower
                        : scanUpperBound;
                result.add(new PartitionSchema.EncodedRangeBoundsWithHashSchema(
                        curPoint, upperBound, tableWideHashSchema));
            } else if (cmp == 0) {
                // curPoint is at the lower bound of curRange.
                if (scanUpperBound.length == 0
                        || (curRange.upper.length != 0 && Bytes.memcmp(curRange.upper, scanUpperBound) <= 0)) {
                    // curRange lies entirely within the scan.
                    result.add(curRange);
                } else {
                    // curRange extends past the scan's upper bound.
                    result.add(new PartitionSchema.EncodedRangeBoundsWithHashSchema(
                            curPoint, scanUpperBound, curRange.hashSchemas));
                }
                ++curIdx;
            } else {
                // curPoint is inside curRange: emit up to the nearer upper bound.
                if (curRange.upper.length == 0
                        || (scanUpperBound.length != 0 && Bytes.memcmp(scanUpperBound, curRange.upper) <= 0)) {
                    result.add(new PartitionSchema.EncodedRangeBoundsWithHashSchema(
                            curPoint, scanUpperBound, curRange.hashSchemas));
                } else {
                    result.add(new PartitionSchema.EncodedRangeBoundsWithHashSchema(
                            curPoint, curRange.upper, curRange.hashSchemas));
                }
                ++curIdx;
            }
            // Advance to the upper bound of the sub-range just emitted.
            curPoint = result.get(result.size() - 1).upper;
        }

        // If the custom-hash ranges ran out before the scan's upper bound was
        // reached, cover the rest with the table-wide hash schema.
        Preconditions.checkState(!result.isEmpty());
        byte[] lastUpperBound = result.get(result.size() - 1).upper;
        if (Bytes.memcmp(lastUpperBound, scanUpperBound) != 0) {
            Preconditions.checkState(Bytes.memcmp(curPoint, lastUpperBound) == 0);
            result.add(new PartitionSchema.EncodedRangeBoundsWithHashSchema(
                    curPoint, scanUpperBound, tableWideHashSchema));
        }
        return result;
    }

    /** Test-only access to {@link #pruneHashComponent(Schema, PartitionSchema.HashBucketSchema, Map)}. */
    static BitSet pruneHashComponentV2ForTest(Schema schema, PartitionSchema.HashBucketSchema hashSchema,
                                              Map<String, KuduPredicate> predicates) {
        return pruneHashComponent(schema, hashSchema, predicates);
    }

    /**
     * Returns the buckets of one hash component that may contain matching rows.
     *
     * <p>Buckets can only be pruned when every column of the hash component has
     * an EQUALITY or IN_LIST predicate. In that case the bucket of every
     * combination of predicate values is computed. Otherwise all buckets are
     * returned.
     */
    private static BitSet pruneHashComponent(Schema schema, PartitionSchema.HashBucketSchema hashSchema,
                                             Map<String, KuduPredicate> predicates) {
        BitSet hashBuckets = new BitSet(hashSchema.getNumBuckets());
        List<Integer> columnIdxs = idsToIndexes(schema, hashSchema.getColumnIds());
        List<List<byte[]>> predicateValueList = new ArrayList<>();
        for (int idx : columnIdxs) {
            ColumnSchema column = schema.getColumnByIndex(idx);
            KuduPredicate predicate = predicates.get(column.getName());
            if (predicate == null
                    || (predicate.getType() != KuduPredicate.PredicateType.EQUALITY
                        && predicate.getType() != KuduPredicate.PredicateType.IN_LIST)) {
                hashBuckets.set(0, hashSchema.getNumBuckets());
                return hashBuckets;
            }
            // Must be List<byte[]> (not List<Object>): generics are invariant, so
            // only List<byte[]> can be added to predicateValueList.
            List<byte[]> predicateValues;
            if (predicate.getType() == KuduPredicate.PredicateType.EQUALITY) {
                predicateValues = Collections.singletonList(predicate.getLower());
            } else {
                predicateValues = Arrays.asList(predicate.getInListValues());
            }
            predicateValueList.add(predicateValues);
        }
        List<byte[]> valuesCombination = new ArrayList<>();
        computeHashBuckets(schema, hashSchema, hashBuckets, columnIdxs, predicateValueList, valuesCombination);
        return hashBuckets;
    }

    /**
     * Recursively enumerates the cartesian product of predicate values and sets
     * the hash bucket of each full combination in {@code hashBuckets}.
     *
     * <p>{@code valuesCombination} holds the values chosen so far, one per
     * leading column. Recursion stops early once every bucket is set.
     */
    private static void computeHashBuckets(Schema schema, PartitionSchema.HashBucketSchema hashSchema,
                                           BitSet hashBuckets, List<Integer> columnIdxs,
                                           List<List<byte[]>> predicateValueList, List<byte[]> valuesCombination) {
        if (hashBuckets.cardinality() == hashSchema.getNumBuckets()) {
            return;
        }
        int level = valuesCombination.size();
        if (level == columnIdxs.size()) {
            // A value has been chosen for every hash column: hash this row.
            PartialRow row = schema.newPartialRow();
            for (int i = 0; i < valuesCombination.size(); ++i) {
                row.setRaw(columnIdxs.get(i), valuesCombination.get(i));
            }
            hashBuckets.set(KeyEncoder.getHashBucket(row, hashSchema));
            return;
        }
        for (byte[] value : predicateValueList.get(level)) {
            valuesCombination.add(value);
            computeHashBuckets(schema, hashSchema, hashBuckets, columnIdxs, predicateValueList, valuesCombination);
            valuesCombination.remove(valuesCombination.size() - 1);
        }
    }
}

