/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.sagemakergeneric;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.sagemakergeneric.GenericLLMHandling;
import com.dataiku.dip.shaker.processors.expr.TokenizedText;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.Arrays;

public interface GenericEmbeddingLLMMarshall {
    public static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.generic_text_embedding_marshall");

    public JsonObject prepareInputsEmbedding(LLMClient.EmbeddingQuery var1) throws IllegalArgumentException;

    public LLMClient.SimpleEmbeddingResponse parseEmbeddingResponse(JsonElement var1) throws IOException;

    public static GenericEmbeddingLLMMarshall get(GenericLLMHandling family, AbstractLLMConnection.AbstractLLMConnectionParams params) {
        switch (family) {
            case AMAZON_TITAN_TEXT_EMBEDDING: {
                return new AmazonTitanTextEmbeddingMarshall();
            }
            case AMAZON_TITAN_TEXT_IMAGE_EMBEDDING: {
                return new AmazonTitanTextImageEmbeddingMarshall();
            }
            case COHERE_EMBED: {
                return new CohereEmbeddingLLMMarshall();
            }
        }
        throw new Error("Unknown GenericLLMHandling family for text embedding: " + String.valueOf((Object)family));
    }

    public static class AmazonTitanTextEmbeddingMarshall
    implements GenericEmbeddingLLMMarshall {
        @Override
        public JsonObject prepareInputsEmbedding(LLMClient.EmbeddingQuery query) throws IllegalArgumentException {
            JF.ObjectBuilder ob = JF.obj();
            if (query.hasText()) {
                ob = ob.with("inputText", query.text);
            }
            if (query.hasImage()) {
                throw new IllegalArgumentException("Query with image are not supported for text-only Embedding models");
            }
            return ob.get();
        }

        @Override
        public LLMClient.SimpleEmbeddingResponse parseEmbeddingResponse(JsonElement response) throws IOException {
            JsonObject jo = response instanceof JsonArray ? (JsonObject)((JsonArray)response).get(0) : (JsonObject)response;
            JsonArray embeddingJson = jo.get("embedding").getAsJsonArray();
            LLMClient.SimpleEmbeddingResponse ret = new LLMClient.SimpleEmbeddingResponse();
            ret.embedding = new double[embeddingJson.size()];
            for (int i = 0; i < embeddingJson.size(); ++i) {
                ret.embedding[i] = embeddingJson.get(i).getAsDouble();
            }
            ret.promptTokens = jo.get("inputTextTokenCount").getAsInt();
            return ret;
        }
    }

    public static class AmazonTitanTextImageEmbeddingMarshall
    extends AmazonTitanTextEmbeddingMarshall {
        @Override
        public JsonObject prepareInputsEmbedding(LLMClient.EmbeddingQuery query) {
            JF.ObjectBuilder ob = JF.obj();
            if (query.hasText()) {
                ob = ob.with("inputText", query.text);
            }
            if (query.hasImage()) {
                ob = ob.with("inputImage", query.inlineImage);
            }
            return ob.get();
        }
    }

    public static class CohereEmbeddingLLMMarshall
    implements GenericEmbeddingLLMMarshall {
        @Override
        public JsonObject prepareInputsEmbedding(LLMClient.EmbeddingQuery query) throws IllegalArgumentException {
            JF.ObjectBuilder ob = JF.obj();
            if (query.hasText()) {
                ob.with("texts", Arrays.asList(query.text));
                ob.with("input_type", "search_document");
            }
            if (query.hasImage()) {
                throw new IllegalArgumentException("Image query is not supported for Cohere Embedding models");
            }
            return ob.get();
        }

        @Override
        public LLMClient.SimpleEmbeddingResponse parseEmbeddingResponse(JsonElement response) throws IOException {
            JsonObject jo = response instanceof JsonArray ? (JsonObject)((JsonArray)response).get(0) : (JsonObject)response;
            JsonElement embeddingJsonEl = jo.get("embeddings");
            if (!(embeddingJsonEl instanceof JsonArray)) {
                embeddingJsonEl = embeddingJsonEl.getAsJsonObject().get("float");
            }
            JsonArray embeddingArray = embeddingJsonEl.getAsJsonArray().get(0).getAsJsonArray();
            LLMClient.SimpleEmbeddingResponse ret = new LLMClient.SimpleEmbeddingResponse();
            ret.embedding = new double[embeddingArray.size()];
            for (int i = 0; i < embeddingArray.size(); ++i) {
                ret.embedding[i] = embeddingArray.get(i).getAsDouble();
            }
            String inputText = jo.get("texts").getAsJsonArray().get(0).getAsString();
            ret.promptTokens = (int)(2.5f * (float)new TokenizedText(inputText).size());
            ret.tokenCountsAreEstimated = true;
            return ret;
        }
    }
}

