/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.search;

import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.ConjunctionDISI;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.FilteredDocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.TimeLimitingKnnCollectorManager;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.search.VectorSimilarityCollector;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.apache.lucene.util.Bits;

/**
 * Base class for vector queries that match documents whose vector similarity score is at least
 * {@link #resultSimilarity}, optionally restricted to the documents matching {@link #filter}.
 *
 * <p>Subclasses supply the exact per-segment scorer ({@link #createVectorScorer}) and the approximate
 * search ({@link #approximateSearch}). The graph search is driven by a {@link VectorSimilarityCollector}
 * configured with {@link #traversalSimilarity} and {@link #resultSimilarity}.
 */
abstract class AbstractVectorSimilarityQuery extends Query {
    /**
     * HNSW search strategy constructed with argument {@code 0}. It is not referenced in this class.
     * NOTE(review): presumably used by subclasses; confirm.
     */
    static final KnnSearchStrategy.Hnsw DEFAULT_STRATEGY = new KnnSearchStrategy.Hnsw(0);

    /** Name of the vector field being searched; never null. */
    protected final String field;
    /** Similarity passed to the collector for graph traversal; must be {@code <= resultSimilarity}. */
    protected final float traversalSimilarity;
    /** Minimum similarity score for a document to match. */
    protected final float resultSimilarity;
    /** Optional filter query restricting candidate documents; may be null. */
    protected final Query filter;

    /**
     * @param field vector field name (must not be null)
     * @param traversalSimilarity similarity threshold handed to the collector for traversal
     * @param resultSimilarity minimum score for a document to be returned
     * @param filter optional filter query, or {@code null} for none
     * @throws IllegalArgumentException if {@code traversalSimilarity > resultSimilarity}
     * @throws NullPointerException if {@code field} is null
     */
    AbstractVectorSimilarityQuery(String field, float traversalSimilarity, float resultSimilarity, Query filter) {
        if (traversalSimilarity > resultSimilarity) {
            throw new IllegalArgumentException("traversalSimilarity should be <= resultSimilarity");
        }
        this.field = Objects.requireNonNull(field, "field");
        this.traversalSimilarity = traversalSimilarity;
        this.resultSimilarity = resultSimilarity;
        this.filter = filter;
    }

    /**
     * Returns a collector manager that creates a {@link VectorSimilarityCollector} bound to this query's
     * traversal and result similarities and to the caller-supplied visit limit.
     */
    protected KnnCollectorManager getKnnCollectorManager() {
        return (visitLimit, searchStrategy, context) ->
                new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitLimit);
    }

    /**
     * Creates an exact scorer over this query's vectors in the given segment.
     *
     * @return the scorer, or {@code null} if the segment cannot be scored for this field (callers treat
     *     null as "no matches")
     */
    abstract VectorScorer createVectorScorer(LeafReaderContext context) throws IOException;

    /**
     * Runs the approximate search in one segment.
     *
     * @param context the segment
     * @param acceptDocs documents eligible to match
     * @param visitLimit limit forwarded to the collector; this class passes either
     *     {@link Integer#MAX_VALUE} or the filter's cost
     * @param knnCollectorManager manager used to create the collector
     * @return the collected hits; {@code totalHits.relation()} is inspected to decide whether the result
     *     is complete
     */
    protected abstract TopDocs approximateSearch(LeafReaderContext context, AcceptDocs acceptDocs, int visitLimit,
            KnnCollectorManager knnCollectorManager) throws IOException;

    @Override
    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
        return new Weight(this) {
            // The filter only restricts the candidate set, so it is weighted without scores.
            final Weight filterWeight = filter == null
                    ? null
                    : searcher.createWeight(searcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1.0f);
            final QueryTimeout queryTimeout = searcher.getTimeout();
            // Wraps this query's collector manager so the searcher's timeout is honored during search.
            final TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager =
                    new TimeLimitingKnnCollectorManager(getKnnCollectorManager(), queryTimeout);

            @Override
            public Explanation explain(LeafReaderContext context, int doc) throws IOException {
                if (filterWeight != null) {
                    Scorer filterScorer = filterWeight.scorer(context);
                    if (filterScorer == null || filterScorer.iterator().advance(doc) > doc) {
                        return Explanation.noMatch("Doc does not match the filter");
                    }
                }
                VectorScorer scorer = createVectorScorer(context);
                if (scorer == null) {
                    return Explanation.noMatch("Not indexed as the correct vector field");
                }
                if (scorer.iterator().advance(doc) != doc) {
                    return Explanation.noMatch("No vector found for doc");
                }
                float score = scorer.score();
                if (score >= resultSimilarity) {
                    return Explanation.match(boost * score, "Score above threshold");
                }
                return Explanation.noMatch("Score below threshold");
            }

            @Override
            public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
                LeafReader leafReader = context.reader();
                Bits liveDocs = leafReader.getLiveDocs();

                if (filterWeight == null) {
                    // No filter: search all live docs with no visit limit and use the hits directly.
                    TopDocs results = approximateSearch(context,
                            AcceptDocs.fromLiveDocs(liveDocs, leafReader.maxDoc()),
                            Integer.MAX_VALUE, timeLimitingKnnCollectorManager);
                    return VectorSimilarityScorerSupplier.fromScoreDocs(boost, results.scoreDocs);
                }

                AcceptDocs acceptDocs = AcceptDocs.fromIteratorSupplier(() -> {
                    Scorer filterScorer = filterWeight.scorer(context);
                    if (filterScorer == null) {
                        return DocIdSetIterator.empty();
                    }
                    return filterScorer.iterator();
                }, liveDocs, leafReader.maxDoc());

                int cardinality = acceptDocs.cost();
                if (cardinality == 0) {
                    return null;
                }

                // The filter's cost is used as the visit limit. Use the approximate hits if the hit count
                // is exact, or if the query has timed out; otherwise fall back to exactly scoring every
                // accepted doc.
                TopDocs results = approximateSearch(context, acceptDocs, cardinality, timeLimitingKnnCollectorManager);
                if (results.totalHits.relation() == TotalHits.Relation.EQUAL_TO
                        || (queryTimeout != null && queryTimeout.shouldExit())) {
                    return VectorSimilarityScorerSupplier.fromScoreDocs(boost, results.scoreDocs);
                }
                return VectorSimilarityScorerSupplier.fromAcceptDocs(
                        boost, createVectorScorer(context), acceptDocs.iterator(), resultSimilarity);
            }

            @Override
            public boolean isCacheable(LeafReaderContext ctx) {
                return true;
            }
        };
    }

    @Override
    public void visit(QueryVisitor visitor) {
        if (visitor.acceptField(field)) {
            visitor.visitLeaf(this);
        }
    }

    @Override
    public boolean equals(Object o) {
        if (!sameClassAs(o)) {
            return false;
        }
        AbstractVectorSimilarityQuery other = (AbstractVectorSimilarityQuery) o;
        return Objects.equals(field, other.field)
                && Float.compare(other.traversalSimilarity, traversalSimilarity) == 0
                && Float.compare(other.resultSimilarity, resultSimilarity) == 0
                && Objects.equals(filter, other.filter);
    }

    @Override
    public int hashCode() {
        return Objects.hash(field, traversalSimilarity, resultSimilarity, filter);
    }

    /**
     * Supplies a {@link Scorer} over a prepared iterator. The driving iterator writes the boosted score
     * of its current doc into the one-element {@code cachedScore} array, and {@link Scorer#score()} simply
     * reads it back.
     */
    private static class VectorSimilarityScorerSupplier extends ScorerSupplier {
        /** Orders hits by ascending doc id; shared so advance() does not build a comparator per call. */
        private static final Comparator<ScoreDoc> BY_DOC = Comparator.comparingInt(scoreDoc -> scoreDoc.doc);

        final DocIdSetIterator iterator;
        final float[] cachedScore;

        VectorSimilarityScorerSupplier(DocIdSetIterator iterator, float[] cachedScore) {
            this.iterator = iterator;
            this.cachedScore = cachedScore;
        }

        /**
         * Builds a supplier that iterates precomputed hits.
         *
         * <p>Sorts {@code scoreDocs} in place by doc id.
         *
         * @return the supplier, or {@code null} if there are no hits
         */
        static VectorSimilarityScorerSupplier fromScoreDocs(float boost, ScoreDoc[] scoreDocs) {
            if (scoreDocs.length == 0) {
                return null;
            }
            // Sorting by doc id lets advance() binary-search.
            Arrays.sort(scoreDocs, BY_DOC);
            float[] cachedScore = new float[1];
            DocIdSetIterator iterator = new DocIdSetIterator() {
                int index = -1;

                @Override
                public int docID() {
                    if (index < 0) {
                        return -1;
                    }
                    if (index >= scoreDocs.length) {
                        return NO_MORE_DOCS;
                    }
                    // Side effect: publish the boosted score of the current hit for Scorer#score().
                    cachedScore[0] = boost * scoreDocs[index].score;
                    return scoreDocs[index].doc;
                }

                @Override
                public int nextDoc() {
                    index++;
                    return docID();
                }

                @Override
                public int advance(int target) {
                    index = Arrays.binarySearch(scoreDocs, new ScoreDoc(target, 0.0f), BY_DOC);
                    if (index < 0) {
                        // Not found: binarySearch returned -(insertionPoint) - 1, i.e. the first doc after target.
                        index = -1 - index;
                    }
                    return docID();
                }

                @Override
                public long cost() {
                    return scoreDocs.length;
                }
            };
            return new VectorSimilarityScorerSupplier(iterator, cachedScore);
        }

        /**
         * Builds a supplier that exactly scores every doc present in both the vector iterator and
         * {@code acceptDocs}, keeping only docs whose raw score is {@code >= threshold}.
         *
         * @return the supplier, or {@code null} if {@code scorer} is null
         */
        static VectorSimilarityScorerSupplier fromAcceptDocs(float boost, VectorScorer scorer, DocIdSetIterator acceptDocs,
                float threshold) {
            if (scorer == null) {
                return null;
            }
            float[] cachedScore = new float[1];
            DocIdSetIterator vectorIterator = scorer.iterator();
            DocIdSetIterator conjunction = ConjunctionDISI.createConjunction(List.of(vectorIterator, acceptDocs), List.of());
            DocIdSetIterator iterator = new FilteredDocIdSetIterator(conjunction) {
                @Override
                protected boolean match(int doc) throws IOException {
                    assert doc == vectorIterator.docID();
                    float score = scorer.score();
                    // The threshold applies to the raw score; the boosted score is what gets reported.
                    cachedScore[0] = score * boost;
                    return score >= threshold;
                }
            };
            return new VectorSimilarityScorerSupplier(iterator, cachedScore);
        }

        @Override
        public Scorer get(long leadCost) {
            return new Scorer() {
                @Override
                public int docID() {
                    return iterator.docID();
                }

                @Override
                public DocIdSetIterator iterator() {
                    return iterator;
                }

                @Override
                public float getMaxScore(int upTo) {
                    // No score bound is known without scoring.
                    return Float.POSITIVE_INFINITY;
                }

                @Override
                public float score() {
                    return cachedScore[0];
                }
            };
        }

        @Override
        public long cost() {
            return iterator.cost();
        }
    }
}

