/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.vec.internal;

import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.lang.invoke.MethodHandle;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
import org.elasticsearch.nativeaccess.NativeAccess;
import org.elasticsearch.nativeaccess.VectorSimilarityFunctions;
import org.elasticsearch.vec.VectorScorer;
import org.elasticsearch.vec.internal.DotProduct;
import org.elasticsearch.vec.internal.Euclidean;
import org.elasticsearch.vec.internal.IndexInputUtils;
import org.elasticsearch.vec.internal.MaximumInnerProduct;

/**
 * Base class for {@link VectorScorer}s over scalar-quantized vectors stored in an {@link IndexInput}.
 *
 * <p>Scores can be computed with native similarity functions obtained from {@link NativeAccess},
 * operating directly on the {@link MemorySegment}s that back the input (see {@link #segmentSlice}).
 * {@link #fallbackScore(long, long)} offers an alternative path that reads the vectors through the
 * {@link IndexInput} API and scores them with a {@link ScalarQuantizedVectorSimilarity}.
 */
abstract sealed class AbstractScalarQuantizedVectorScorer
implements VectorScorer
permits DotProduct, Euclidean, MaximumInnerProduct {
    /** Native similarity functions; class initialization throws {@link AssertionError} if they are absent. */
    static final VectorSimilarityFunctions DISTANCE_FUNCS = (VectorSimilarityFunctions)NativeAccess.instance().getVectorSimilarityFunctions().orElseThrow(AssertionError::new);
    // NOTE: these must stay declared after DISTANCE_FUNCS, since static initializers run in textual order.
    /** Handle from {@code DISTANCE_FUNCS.dotProductHandle()}, invoked by {@link #dotProduct}. */
    static final MethodHandle DOT_PRODUCT = DISTANCE_FUNCS.dotProductHandle();
    /** Handle from {@code DISTANCE_FUNCS.squareDistanceHandle()}, invoked by {@link #squareDistance}. */
    static final MethodHandle SQUARE_DISTANCE = DISTANCE_FUNCS.squareDistanceHandle();

    protected final int dims;
    protected final int maxOrd;
    protected final float scoreCorrectionConstant;
    protected final IndexInput input;
    /** The only backing segment when {@link #segments} has exactly one element, otherwise {@code null}. */
    protected final MemorySegment segment;
    /** All segments backing {@link #input}, as returned by {@code IndexInputUtils.segmentArray}. */
    protected final MemorySegment[] segments;
    /** Added to positions in the multi-segment case ({@code IndexInputUtils.offset(input)}); 0 for a single segment. */
    protected final long offset;
    /** Right-shift that maps an absolute position to an index into {@link #segments}. */
    protected final int chunkSizePower;
    /** Mask that maps an absolute position to an offset within its segment. */
    protected final long chunkSizeMask;
    private final ScalarQuantizedVectorSimilarity fallbackScorer;

    /**
     * Creates a scorer over {@code input}, resolving its backing memory segments up front.
     *
     * @param dims                    number of quantized bytes per vector
     * @param maxOrd                  value returned by {@link #maxOrd()} and used by {@link #checkOrdinal}
     * @param scoreCorrectionConstant stored for use by subclasses
     * @param input                   the data source; also read directly by {@link #fallbackScore}
     * @param fallbackScorer          similarity used by {@link #fallbackScore}
     */
    protected AbstractScalarQuantizedVectorScorer(int dims, int maxOrd, float scoreCorrectionConstant, IndexInput input, ScalarQuantizedVectorSimilarity fallbackScorer) {
        this.dims = dims;
        this.maxOrd = maxOrd;
        this.scoreCorrectionConstant = scoreCorrectionConstant;
        this.input = input;
        this.fallbackScorer = fallbackScorer;
        this.segments = IndexInputUtils.segmentArray(input);
        if (this.segments.length == 1) {
            this.segment = this.segments[0];
            this.offset = 0L;
        } else {
            this.segment = null;
            this.offset = IndexInputUtils.offset(input);
        }
        this.chunkSizePower = IndexInputUtils.chunkSizePower(input);
        this.chunkSizeMask = IndexInputUtils.chunkSizeMask(input);
    }

    @Override
    public final int dims() {
        return this.dims;
    }

    @Override
    public final int maxOrd() {
        return this.maxOrd;
    }

    /**
     * Validates an ordinal.
     *
     * @throws IllegalArgumentException if {@code ord} is negative or greater than {@link #maxOrd}
     */
    protected final void checkOrdinal(int ord) {
        // NOTE(review): ord == maxOrd is accepted, i.e. maxOrd is treated as inclusive. Confirm this
        // matches how callers compute maxOrd (inclusive maximum vs. vector count).
        if (ord < 0 || ord > this.maxOrd) {
            throw new IllegalArgumentException("illegal ordinal: " + ord);
        }
    }

    /**
     * Scores two vectors by reading them through the {@link IndexInput} API rather than via segments.
     *
     * <p>At each offset the layout read is {@code dims} bytes followed by an {@code int} whose bits are
     * reinterpreted as a float correction value. This method moves the file pointer of {@link #input}.
     *
     * @param firstByteOffset  file position of the first vector
     * @param secondByteOffset file position of the second vector
     * @return the score computed by the fallback {@link ScalarQuantizedVectorSimilarity}
     * @throws IOException if reading from {@link #input} fails
     */
    protected final float fallbackScore(long firstByteOffset, long secondByteOffset) throws IOException {
        this.input.seek(firstByteOffset);
        byte[] a = new byte[this.dims];
        this.input.readBytes(a, 0, a.length);
        float aOffsetValue = Float.intBitsToFloat(this.input.readInt());

        this.input.seek(secondByteOffset);
        byte[] b = new byte[this.dims];
        this.input.readBytes(b, 0, b.length);
        float bOffsetValue = Float.intBitsToFloat(this.input.readInt());

        return this.fallbackScorer.score(a, aOffsetValue, b, bOffsetValue);
    }

    /**
     * Returns a slice of {@code length} bytes at {@code pos}, if the whole range lies in one backing segment.
     *
     * @param pos    position relative to the input (the multi-segment case adds {@link #offset})
     * @param length number of bytes in the slice
     * @return the slice, or {@code null} if {@code [pos, pos + length)} does not fit in a single segment
     */
    protected final MemorySegment segmentSlice(long pos, int length) {
        final MemorySegment seg;
        final long local;
        if (this.segment != null) {
            seg = this.segment;
            local = pos;
        } else {
            long absolute = pos + this.offset;
            seg = this.segments[(int) (absolute >> this.chunkSizePower)];
            local = absolute & this.chunkSizeMask;
        }
        // Both the start and the end of the range must lie within [0, byteSize] in BOTH branches.
        // Previously the single-segment branch checked only the start, so an overrunning range made
        // asSlice throw instead of returning null like the multi-segment branch.
        long limit = seg.byteSize() + 1L;
        if (checkIndex(local, limit) && checkIndex(local + length, limit)) {
            return seg.asSlice(local, length);
        }
        return null;
    }

    /** Returns whether {@code 0 <= index < length}. */
    static boolean checkIndex(long index, long length) {
        return index >= 0L && index < length;
    }

    /**
     * Invokes {@link #DOT_PRODUCT} with {@code (a, b, length)}.
     *
     * <p>Unchecked throwables from the handle are rethrown as-is; checked ones are wrapped in a
     * {@link RuntimeException}.
     */
    static int dotProduct(MemorySegment a, MemorySegment b, int length) {
        try {
            // The (int) cast is required: invokeExact is signature-polymorphic, and the cast sets the
            // call-site return type to int. Presumably this matches a (MemorySegment, MemorySegment, int)int handle type.
            return (int) DOT_PRODUCT.invokeExact(a, b, length);
        } catch (Error | RuntimeException e) {
            throw e;
        } catch (Throwable t) {
            throw new RuntimeException(t);
        }
    }

    /**
     * Invokes {@link #SQUARE_DISTANCE} with {@code (a, b, length)}.
     *
     * <p>Unchecked throwables from the handle are rethrown as-is; checked ones are wrapped in a
     * {@link RuntimeException}.
     */
    static int squareDistance(MemorySegment a, MemorySegment b, int length) {
        try {
            // See dotProduct: the cast fixes the call-site descriptor's return type to int.
            return (int) SQUARE_DISTANCE.invokeExact(a, b, length);
        } catch (Error | RuntimeException e) {
            throw e;
        } catch (Throwable t) {
            throw new RuntimeException(t);
        }
    }

    /**
     * Returns whether both segments are native and at least {@code length} bytes long.
     * Returns a boolean so it can be used inside {@code assert} statements.
     */
    static boolean assertSegments(MemorySegment a, MemorySegment b, int length) {
        return a.isNative() && a.byteSize() >= (long)length && b.isNative() && b.byteSize() >= (long)length;
    }
}

