package org.apache.lucene.util.hnsw;

import java.io.IOException;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.SparseFixedBitSet;
import org.apache.lucene.util.VectorUtil;

/* loaded from: input_file:WEB-INF/lib/lucene-core-9.4.1.jar:org/apache/lucene/util/hnsw/HnswGraphSearcher.class */
public class HnswGraphSearcher<T> {
    /**
     * Sentinel returned by {@code HnswGraph.nextNeighbor()} once the neighbors of the node last
     * passed to {@code seek} are exhausted (same value as {@code DocIdSetIterator.NO_MORE_DOCS}).
     */
    private static final int NO_MORE_NEIGHBORS = Integer.MAX_VALUE;

    private final VectorSimilarityFunction similarityFunction;
    private final VectorEncoding vectorEncoding;
    /** Scratch queue of nodes still to be expanded; cleared at the start of every level search. */
    private final NeighborQueue candidates;
    /** Scratch set of ordinals already scored; cleared (and grown if needed) per level search. */
    private BitSet visited;

    /**
     * Creates a searcher that reuses the given scratch structures across level searches.
     *
     * @param vectorEncoding encoding of the stored vectors; selects byte or float comparison
     * @param similarityFunction similarity used to score a stored vector against the query
     * @param candidates scratch queue of nodes to expand; {@link #search} passes a max-heap so the
     *     best-scoring candidate is expanded first
     * @param visited scratch set tracking already-scored ordinals. When it is smaller than the
     *     vector count, a {@link FixedBitSet} is grown and any other implementation is replaced by
     *     a fresh {@link SparseFixedBitSet}
     */
    public HnswGraphSearcher(
            VectorEncoding vectorEncoding,
            VectorSimilarityFunction similarityFunction,
            NeighborQueue candidates,
            BitSet visited) {
        this.vectorEncoding = vectorEncoding;
        this.similarityFunction = similarityFunction;
        this.candidates = candidates;
        this.visited = visited;
    }

    /**
     * Searches the graph for the nearest neighbors of {@code query}.
     *
     * <p>Every level above 0 is searched with a result size of one; the best node found becomes
     * the single entry point of the next level down. Level 0 is then searched for {@code topK}
     * results. For {@link VectorEncoding#BYTE} vectors the query is first converted with
     * {@link VectorUtil#toBytesRef(float[])}.
     *
     * @param query the query vector; its length must equal {@code vectors.dimension()}
     * @param topK maximum number of results to return
     * @param vectors random access to the indexed vectors
     * @param vectorEncoding encoding of the indexed vectors
     * @param similarityFunction similarity used to score vectors against the query
     * @param graph the graph to search
     * @param acceptOrds ordinals allowed in the level-0 results, or {@code null} to accept all
     * @param visitedLimit maximum number of vectors scored across all levels; when reached, the
     *     returned queue is marked incomplete
     * @return the results queue; its visited count is the number of vectors scored on all levels
     * @throws IllegalArgumentException if the query dimension differs from the field dimension
     * @throws IOException if reading vectors or graph neighbors fails
     */
    public static NeighborQueue search(
            float[] query,
            int topK,
            RandomAccessVectorValues vectors,
            VectorEncoding vectorEncoding,
            VectorSimilarityFunction similarityFunction,
            HnswGraph graph,
            Bits acceptOrds,
            int visitedLimit)
            throws IOException {
        if (query.length != vectors.dimension()) {
            throw new IllegalArgumentException("vector query dimension: " + query.length + " differs from field dimension: " + vectors.dimension());
        }
        if (vectorEncoding == VectorEncoding.BYTE) {
            return search(VectorUtil.toBytesRef(query), topK, vectors, vectorEncoding, similarityFunction, graph, acceptOrds, visitedLimit);
        }
        HnswGraphSearcher<float[]> graphSearcher =
                new HnswGraphSearcher<>(
                        vectorEncoding,
                        similarityFunction,
                        new NeighborQueue(topK, true),
                        new SparseFixedBitSet(vectors.size()));
        return searchLayers(graphSearcher, query, topK, vectors, graph, acceptOrds, visitedLimit);
    }

    /**
     * Byte-encoded counterpart of the public float search; the caller has already validated the
     * query dimension.
     */
    private static NeighborQueue search(
            BytesRef query,
            int topK,
            RandomAccessVectorValues vectors,
            VectorEncoding vectorEncoding,
            VectorSimilarityFunction similarityFunction,
            HnswGraph graph,
            Bits acceptOrds,
            int visitedLimit)
            throws IOException {
        HnswGraphSearcher<BytesRef> graphSearcher =
                new HnswGraphSearcher<>(
                        vectorEncoding,
                        similarityFunction,
                        new NeighborQueue(topK, true),
                        new SparseFixedBitSet(vectors.size()));
        return searchLayers(graphSearcher, query, topK, vectors, graph, acceptOrds, visitedLimit);
    }

    /**
     * Descends from the graph's entry node through every level, then searches level 0.
     * Shared by the float and byte search entry points.
     */
    private static <Q> NeighborQueue searchLayers(
            HnswGraphSearcher<Q> graphSearcher,
            Q query,
            int topK,
            RandomAccessVectorValues vectors,
            HnswGraph graph,
            Bits acceptOrds,
            int visitedLimit)
            throws IOException {
        int[] eps = {graph.entryNode()};
        int numVisited = 0;
        for (int level = graph.numLevels() - 1; level >= 1; level--) {
            // Upper levels: unfiltered, single best node becomes the next entry point.
            NeighborQueue upper = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
            numVisited += upper.visitedCount();
            // The visit budget is shared by all levels.
            visitedLimit -= upper.visitedCount();
            if (upper.incomplete()) {
                upper.setVisitedCount(numVisited);
                return upper;
            }
            eps[0] = upper.pop();
        }
        NeighborQueue results = graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);
        results.setVisitedCount(results.visitedCount() + numVisited);
        return results;
    }

    /**
     * Searches a single graph level from the given entry points, with no accept filter and no
     * visit limit.
     *
     * @param query the query, a {@code float[]} or {@link BytesRef} matching the vector encoding
     * @param topK maximum number of results to keep
     * @param level the graph level to search
     * @param eps entry-point ordinals
     * @param vectors random access to the indexed vectors
     * @param graph the graph to search
     * @return the results queue (at most {@code topK} entries)
     * @throws IOException if reading vectors or graph neighbors fails
     */
    public NeighborQueue searchLevel(
            T query, int topK, int level, int[] eps, RandomAccessVectorValues vectors, HnswGraph graph)
            throws IOException {
        return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE);
    }

    /**
     * Best-first search of one level. Candidates are expanded while their score is at least the
     * current admission bound; once {@code topK} results are held, that bound is the top score of
     * the results queue (a queue created with {@code maxHeap=false}).
     *
     * @param acceptOrds ordinals allowed in the results, or {@code null} to accept all; rejected
     *     ordinals are still expanded as candidates
     * @param visitedLimit maximum number of vectors to score; when reached the search stops and the
     *     results are marked incomplete
     * @return the results queue, trimmed to at most {@code topK} entries, with its visited count set
     */
    private NeighborQueue searchLevel(
            T query,
            int topK,
            int level,
            int[] eps,
            RandomAccessVectorValues vectors,
            HnswGraph graph,
            Bits acceptOrds,
            int visitedLimit)
            throws IOException {
        int size = graph.size();
        NeighborQueue results = new NeighborQueue(topK, false);
        prepareScratchState(vectors.size());

        int numVisited = 0;
        for (int ep : eps) {
            if (visited.getAndSet(ep)) {
                continue;
            }
            if (numVisited >= visitedLimit) {
                results.markIncomplete();
                break;
            }
            float score = compare(query, vectors, ep);
            numVisited++;
            candidates.add(ep, score);
            if (acceptOrds == null || acceptOrds.get(ep)) {
                results.add(ep, score);
            }
        }

        // Minimum similarity a neighbor must have to be considered at all.
        float minAcceptedSimilarity = Float.NEGATIVE_INFINITY;
        if (results.size() >= topK) {
            minAcceptedSimilarity = results.topScore();
        }
        // Kept as ">=" in the loop condition (not a "<" break) so a NaN top score still stops.
        while (candidates.size() > 0 && !results.incomplete() && candidates.topScore() >= minAcceptedSimilarity) {
            graph.seek(level, candidates.pop());
            int friendOrd;
            while ((friendOrd = graph.nextNeighbor()) != NO_MORE_NEIGHBORS) {
                assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
                if (visited.getAndSet(friendOrd)) {
                    continue;
                }
                if (numVisited >= visitedLimit) {
                    results.markIncomplete();
                    break;
                }
                float friendSimilarity = compare(query, vectors, friendOrd);
                numVisited++;
                if (friendSimilarity >= minAcceptedSimilarity) {
                    candidates.add(friendOrd, friendSimilarity);
                    if ((acceptOrds == null || acceptOrds.get(friendOrd))
                            && results.insertWithOverflow(friendOrd, friendSimilarity)
                            && results.size() >= topK) {
                        minAcceptedSimilarity = results.topScore();
                    }
                }
            }
        }
        while (results.size() > topK) {
            results.pop();
        }
        results.setVisitedCount(numVisited);
        return results;
    }

    /** Scores the vector at {@code ord} against the query using the configured encoding. */
    private float compare(T query, RandomAccessVectorValues vectors, int ord) throws IOException {
        if (vectorEncoding == VectorEncoding.BYTE) {
            return similarityFunction.compare((BytesRef) query, vectors.binaryValue(ord));
        }
        return similarityFunction.compare((float[]) query, vectors.vectorValue(ord));
    }

    /**
     * Clears the candidate queue and the visited set, making sure the visited set can hold
     * {@code capacity} ordinals.
     */
    private void prepareScratchState(int capacity) {
        candidates.clear();
        if (visited.length() < capacity) {
            if (visited instanceof FixedBitSet) {
                visited = FixedBitSet.ensureCapacity((FixedBitSet) visited, capacity);
            } else {
                // Other BitSet implementations (e.g. SparseFixedBitSet) cannot be grown; casting
                // them to FixedBitSet would throw ClassCastException, so allocate a new one.
                visited = new SparseFixedBitSet(capacity);
            }
        }
        // ensureCapacity may reuse the old words, so stale bits must always be cleared.
        visited.clear(0, visited.length());
    }
}
