/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.openai;

import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.openai.OpenAIRerankingHandlingMode;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.IOException;

/**
 * Converts reranking requests and responses between the generic {@link LLMClient} reranking
 * types and the JSON format of a provider-specific reranking API reached through Azure AI Foundry.
 *
 * <p>Obtain an implementation for a given handling mode via
 * {@link #get(OpenAIRerankingHandlingMode)}.
 */
public interface AzureAIFoundryRerankingLLMMarshall {
    /** Shared logger for reranking marshalls (not used by the code in this file). */
    DKULogger logger = DKULogger.getLogger("dku.llm.azure_ai_foundry_reranking_marshall");

    /**
     * Builds the JSON request body for a reranking call.
     *
     * @param modelId identifier of the reranking model to call
     * @param query the query text and the documents to rerank
     * @return the JSON request body
     * @throws IllegalArgumentException if the query cannot be converted into a request
     */
    JsonObject prepareInputsReranking(String modelId, LLMClient.RerankingQuery query) throws IllegalArgumentException;

    /**
     * Parses the provider's JSON reranking response.
     *
     * @param response the parsed JSON body returned by the reranking endpoint
     * @return the ranked documents, in the order the provider returned them
     * @throws IOException if the response does not have the expected shape
     */
    LLMClient.SingleRerankingResponse parseRerankingResponse(JsonElement response) throws IOException;

    /**
     * Returns the endpoint path of the reranking API. It is a relative path with no leading
     * slash; presumably the caller appends it to the deployment base URL (confirm at call site).
     *
     * @return the relative reranking endpoint path
     */
    String getRerankingEndpoint();

    /**
     * Returns the marshall implementation for the given handling mode.
     *
     * @param handlingMode the reranking handling mode; must not be null
     * @return a new marshall instance for that mode
     * @throws Error if {@code handlingMode} is null or is not supported. {@code Error} is kept
     *     for compatibility with existing callers, although it is not idiomatic.
     */
    static AzureAIFoundryRerankingLLMMarshall get(OpenAIRerankingHandlingMode handlingMode) {
        if (handlingMode == null) {
            throw new Error("Reranking handling mode is null");
        }
        return switch (handlingMode) {
            case COHERE -> new CohereRerankingLLMMarshall();
            default -> throw new Error("Unknown handling mode for reranking: " + handlingMode);
        };
    }

    /**
     * Marshall for Cohere rerank models. It builds requests for, and parses responses from, the
     * {@code providers/cohere/v2/rerank} endpoint.
     */
    public static class CohereRerankingLLMMarshall
    implements AzureAIFoundryRerankingLLMMarshall {
        /**
         * Builds a request of the form
         * {@code {"model": modelId, "query": queryText, "documents": [text, ...]}}.
         *
         * @param modelId identifier of the reranking model, sent as {@code "model"}
         * @param query the query; each document's text is sent in the query's document order
         * @return the JSON request body
         * @throws IllegalArgumentException if {@code query} is null
         */
        @Override
        public JsonObject prepareInputsReranking(String modelId, LLMClient.RerankingQuery query) throws IllegalArgumentException {
            if (query == null) {
                throw new IllegalArgumentException("Reranking query is null");
            }
            JF.ObjectBuilder ob = JF.obj();
            ob.with("model", modelId);
            ob.with("query", query.getQueryText());
            // Only the text of each document is sent, in order. The response's "index" values
            // presumably refer to positions in this list (per the Cohere rerank API).
            ob.with("documents", query.documents.stream().map(doc -> doc.getText()).toList());
            return ob.get();
        }

        /**
         * Parses a Cohere rerank response of the form
         * {@code {"results": [{"index": int, "relevance_score": number}, ...]}}.
         *
         * @param response the JSON response body
         * @return one reranked document per entry of {@code "results"}, in response order
         * @throws IOException if the response is not an object, {@code "results"} is missing or
         *     not an array, or an entry lacks a valid {@code "index"} / {@code "relevance_score"}
         */
        @Override
        public LLMClient.SingleRerankingResponse parseRerankingResponse(JsonElement response) throws IOException {
            if (response == null || !response.isJsonObject()) {
                throw new IOException("Invalid reranking response: expected a JSON object");
            }
            JsonElement resultsElement = response.getAsJsonObject().get("results");
            if (resultsElement == null || !resultsElement.isJsonArray()) {
                throw new IOException("Invalid reranking response: missing or non-array \"results\" field");
            }
            JsonArray results = resultsElement.getAsJsonArray();
            LLMClient.SingleRerankingResponse ret = new LLMClient.SingleRerankingResponse();
            for (int i = 0; i < results.size(); ++i) {
                ret.documents.add(parseRankedDocument(results.get(i), i));
            }
            return ret;
        }

        /**
         * Returns the Cohere v2 rerank endpoint path, relative and without a leading slash.
         *
         * @return {@code "providers/cohere/v2/rerank"}
         */
        @Override
        public String getRerankingEndpoint() {
            return "providers/cohere/v2/rerank";
        }

        /** Converts one entry of the {@code "results"} array into a reranked document. */
        private static LLMClient.RerankedDocument parseRankedDocument(JsonElement item, int position) throws IOException {
            if (item == null || !item.isJsonObject()) {
                throw new IOException("Invalid reranking response: result " + position + " is not a JSON object");
            }
            JsonObject rankedDocument = item.getAsJsonObject();
            JsonElement indexElement = requirePrimitiveField(rankedDocument, "index", position);
            JsonElement scoreElement = requirePrimitiveField(rankedDocument, "relevance_score", position);
            int index;
            float relevanceScore;
            try {
                index = indexElement.getAsInt();
                relevanceScore = scoreElement.getAsFloat();
            } catch (NumberFormatException e) {
                throw new IOException("Invalid reranking response: result " + position
                        + " has a non-numeric \"index\" or \"relevance_score\"", e);
            }
            return new LLMClient.RerankedDocument(index, relevanceScore);
        }

        /** Returns the named field if present as a JSON primitive; otherwise throws {@link IOException}. */
        private static JsonElement requirePrimitiveField(JsonObject obj, String field, int position) throws IOException {
            JsonElement value = obj.get(field);
            if (value == null || !value.isJsonPrimitive()) {
                throw new IOException("Invalid reranking response: result " + position
                        + " is missing field \"" + field + "\"");
            }
            return value;
        }
    }
}

