/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dss.shadelib.org.apache.lucene.codecs.hnsw;

import com.dataiku.dss.shadelib.org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import com.dataiku.dss.shadelib.org.apache.lucene.index.VectorSimilarityFunction;
import com.dataiku.dss.shadelib.org.apache.lucene.util.ArrayUtil;
import com.dataiku.dss.shadelib.org.apache.lucene.util.VectorUtil;
import com.dataiku.dss.shadelib.org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import com.dataiku.dss.shadelib.org.apache.lucene.util.hnsw.RandomVectorScorer;
import com.dataiku.dss.shadelib.org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import com.dataiku.dss.shadelib.org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import com.dataiku.dss.shadelib.org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
import com.dataiku.dss.shadelib.org.apache.lucene.util.quantization.ScalarQuantizer;
import java.io.IOException;

/**
 * {@link FlatVectorsScorer} that scores scalar-quantized byte vectors.
 *
 * <p>When the supplied vector values are {@link RandomAccessQuantizedByteVectorValues}, scoring is
 * performed on the quantized bytes through a {@link ScalarQuantizedVectorSimilarity} built from the
 * values' {@link ScalarQuantizer}. Any other kind of vector values, and every {@code byte[]} query
 * target, is forwarded unchanged to the non-quantized delegate supplied at construction time.
 */
public class ScalarQuantizedVectorScorer
implements FlatVectorsScorer {
    /** Scorer used for non-quantized vector values and for {@code byte[]} query targets. */
    private final FlatVectorsScorer nonQuantizedDelegate;

    /**
     * Quantizes a float query into {@code quantizedQuery} using {@code scalarQuantizer}.
     *
     * <p>For {@link VectorSimilarityFunction#COSINE} the query is L2-normalized first. The
     * normalization is applied to a copy, so it does not modify {@code query}. For EUCLIDEAN,
     * DOT_PRODUCT and MAXIMUM_INNER_PRODUCT the query is quantized as given.
     *
     * @param query the raw float query vector
     * @param quantizedQuery destination for the quantized bytes (callers in this class allocate it
     *     with length {@code query.length})
     * @param similarityFunction the similarity the quantized vectors will be scored with
     * @param scalarQuantizer the quantizer that performs the quantization
     * @return the value returned by {@link ScalarQuantizer#quantize}; this class passes it to
     *     {@link ScalarQuantizedVectorSimilarity#score} as the query's score-correction offset
     * @throws IllegalArgumentException if {@code similarityFunction} is not one of the handled
     *     functions
     */
    public static float quantizeQuery(float[] query, byte[] quantizedQuery, VectorSimilarityFunction similarityFunction, ScalarQuantizer scalarQuantizer) {
        float[] processedQuery;
        switch (similarityFunction) {
            case EUCLIDEAN:
            case DOT_PRODUCT:
            case MAXIMUM_INNER_PRODUCT: {
                processedQuery = query;
                break;
            }
            case COSINE: {
                // Normalize a copy so the caller's query array is left as-is.
                float[] queryCopy = ArrayUtil.copyArray(query);
                VectorUtil.l2normalize(queryCopy);
                processedQuery = queryCopy;
                break;
            }
            default: {
                // Defensive: guards against similarity functions added after this code was written.
                throw new IllegalArgumentException("Unsupported similarity function: " + similarityFunction);
            }
        }
        return scalarQuantizer.quantize(processedQuery, quantizedQuery, similarityFunction);
    }

    /**
     * Creates a scorer that handles quantized vector values itself and delegates everything else.
     *
     * @param flatVectorsScorer scorer used for non-quantized values and {@code byte[]} targets
     */
    public ScalarQuantizedVectorScorer(FlatVectorsScorer flatVectorsScorer) {
        this.nonQuantizedDelegate = flatVectorsScorer;
    }

    /**
     * Returns a supplier of scorers that compare stored vectors against each other.
     *
     * <p>Quantized values get a {@link ScalarQuantizedRandomVectorScorerSupplier}; any other values
     * are passed to the non-quantized delegate.
     */
    @Override
    public RandomVectorScorerSupplier getRandomVectorScorerSupplier(VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) throws IOException {
        if (vectorValues instanceof RandomAccessQuantizedByteVectorValues) {
            RandomAccessQuantizedByteVectorValues quantizedByteVectorValues = (RandomAccessQuantizedByteVectorValues)vectorValues;
            return new ScalarQuantizedRandomVectorScorerSupplier(similarityFunction, quantizedByteVectorValues.getScalarQuantizer(), quantizedByteVectorValues);
        }
        return this.nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
    }

    /**
     * Returns a scorer for a float query.
     *
     * <p>For quantized values, the query is quantized once up front with the values' own quantizer
     * (see {@link #quantizeQuery}). Each scored node is then compared on quantized bytes, with both
     * sides' score-correction constants. Other values are passed to the non-quantized delegate.
     */
    @Override
    public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues, float[] target) throws IOException {
        if (vectorValues instanceof RandomAccessQuantizedByteVectorValues) {
            final RandomAccessQuantizedByteVectorValues quantizedByteVectorValues = (RandomAccessQuantizedByteVectorValues)vectorValues;
            ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer();
            // targetBytes is owned by this scorer alone, so it cannot be overwritten by other reads.
            final byte[] targetBytes = new byte[target.length];
            final float offsetCorrection = ScalarQuantizedVectorScorer.quantizeQuery(target, targetBytes, similarityFunction, scalarQuantizer);
            final ScalarQuantizedVectorSimilarity scalarQuantizedVectorSimilarity = ScalarQuantizedVectorSimilarity.fromVectorSimilarity(similarityFunction, scalarQuantizer.getConstantMultiplier(), scalarQuantizer.getBits());
            return new RandomVectorScorer.AbstractRandomVectorScorer(quantizedByteVectorValues){

                @Override
                public float score(int node) throws IOException {
                    byte[] nodeVector = quantizedByteVectorValues.vectorValue(node);
                    float nodeOffset = quantizedByteVectorValues.getScoreCorrectionConstant(node);
                    return scalarQuantizedVectorSimilarity.score(targetBytes, offsetCorrection, nodeVector, nodeOffset);
                }
            };
        }
        return this.nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
    }

    /**
     * Returns a scorer for a {@code byte[]} query. This overload is always passed to the
     * non-quantized delegate, whatever the type of {@code vectorValues}.
     */
    @Override
    public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues, byte[] target) throws IOException {
        return this.nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
    }

    @Override
    public String toString() {
        return "ScalarQuantizedVectorScorer(nonQuantizedDelegate=" + this.nonQuantizedDelegate + ")";
    }

    /**
     * {@link RandomVectorScorerSupplier} over scalar-quantized byte vectors. Each scorer uses one
     * stored vector (by ordinal) as its query.
     */
    public static class ScalarQuantizedRandomVectorScorerSupplier
    implements RandomVectorScorerSupplier {
        /** Quantized vectors; also the source of query vectors in {@link #scorer(int)}. */
        private final RandomAccessQuantizedByteVectorValues values;
        /** Quantized similarity derived from the similarity function and quantizer settings. */
        private final ScalarQuantizedVectorSimilarity similarity;
        /** Original similarity function; kept for {@link #copy()} and {@link #toString()}. */
        private final VectorSimilarityFunction vectorSimilarityFunction;

        /**
         * Creates a supplier, building the quantized similarity from the quantizer's constant
         * multiplier and bit count.
         *
         * @param similarityFunction the similarity to score with
         * @param scalarQuantizer the quantizer whose parameters configure the similarity
         * @param values the quantized vector values to score
         */
        public ScalarQuantizedRandomVectorScorerSupplier(VectorSimilarityFunction similarityFunction, ScalarQuantizer scalarQuantizer, RandomAccessQuantizedByteVectorValues values) {
            this.similarity = ScalarQuantizedVectorSimilarity.fromVectorSimilarity(similarityFunction, scalarQuantizer.getConstantMultiplier(), scalarQuantizer.getBits());
            this.values = values;
            this.vectorSimilarityFunction = similarityFunction;
        }

        /** Copy constructor used by {@link #copy()}; reuses an already-built similarity. */
        private ScalarQuantizedRandomVectorScorerSupplier(ScalarQuantizedVectorSimilarity similarity, VectorSimilarityFunction vectorSimilarityFunction, RandomAccessQuantizedByteVectorValues values) {
            this.similarity = similarity;
            this.values = values;
            this.vectorSimilarityFunction = vectorSimilarityFunction;
        }

        /**
         * Returns a scorer that compares the stored vector at {@code ord} against other stored
         * vectors.
         *
         * <p>The query bytes are read from this supplier's values and then defensively copied.
         * {@code vectorValue} implementations may hand back a reused internal buffer. Without the
         * copy, a later read on the same values (for example another {@code scorer(int)} call on
         * this supplier) could silently replace the query of a scorer that is still in use.
         * Candidate vectors are read through a separate {@code copy()} of the values.
         */
        @Override
        public RandomVectorScorer scorer(int ord) throws IOException {
            final RandomAccessQuantizedByteVectorValues vectorsCopy = this.values.copy();
            // Clone: the returned array may be a shared buffer that the next vectorValue() overwrites.
            final byte[] queryVector = this.values.vectorValue(ord).clone();
            final float queryOffset = this.values.getScoreCorrectionConstant(ord);
            return new RandomVectorScorer.AbstractRandomVectorScorer(vectorsCopy){

                @Override
                public float score(int node) throws IOException {
                    byte[] nodeVector = vectorsCopy.vectorValue(node);
                    float nodeOffset = vectorsCopy.getScoreCorrectionConstant(node);
                    return similarity.score(queryVector, queryOffset, nodeVector, nodeOffset);
                }
            };
        }

        /** Returns a supplier sharing this similarity but reading from a fresh copy of the values. */
        @Override
        public RandomVectorScorerSupplier copy() throws IOException {
            return new ScalarQuantizedRandomVectorScorerSupplier(this.similarity, this.vectorSimilarityFunction, this.values.copy());
        }

        @Override
        public String toString() {
            return "ScalarQuantizedRandomVectorScorerSupplier(vectorSimilarityFunction=" + this.vectorSimilarityFunction + ")";
        }
    }
}

