/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.sagemakergeneric;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.sagemakergeneric.GenericLLMHandling;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.IOException;

/**
 * Converts reranking queries into JSON request payloads, and JSON responses back into reranking
 * results, for a given {@link GenericLLMHandling} payload family.
 */
public interface GenericRerankingLLMMarshall {
    DKULogger logger = DKULogger.getLogger("dku.llm.generic_reranking_marshall");

    /**
     * Builds the JSON request body for a reranking call.
     *
     * @param query the query text and the candidate documents to rerank
     * @return the request payload
     * @throws IllegalArgumentException if the query cannot be expressed in this payload format
     */
    JsonObject prepareInputsReranking(LLMClient.RerankingQuery query) throws IllegalArgumentException;

    /**
     * Parses the endpoint's JSON response into reranked documents.
     *
     * @param response the parsed JSON response body
     * @return the reranked documents, in the order the response lists them
     * @throws IOException if the response does not have the expected shape
     */
    LLMClient.SingleRerankingResponse parseRerankingResponse(JsonElement response) throws IOException;

    /**
     * Returns the marshall implementation for the given payload family.
     *
     * @param family the payload family of the endpoint; must not be null
     * @param params connection parameters (not used by any family handled here)
     * @return a marshall for {@code family}
     * @throws Error if {@code family} has no reranking marshall (kept as {@code Error} for
     *     compatibility with existing callers)
     */
    static GenericRerankingLLMMarshall get(GenericLLMHandling family, AbstractLLMConnection.AbstractLLMConnectionParams params) {
        return switch (family) {
            case COHERE_RERANK -> new CohereRerankingLLMMarshall();
            default -> throw new Error("Unknown GenericLLMHandling family for reranking: " + family);
        };
    }

    /**
     * Marshall for the {@code COHERE_RERANK} family.
     *
     * <p>Requests carry {@code query}, {@code api_version} and the document texts under
     * {@code documents}. Responses are expected to contain a {@code results} array of objects with
     * {@code index} and {@code relevance_score} members.
     */
    public static class CohereRerankingLLMMarshall implements GenericRerankingLLMMarshall {
        /** Value sent as {@code api_version} in every request. */
        private static final int API_VERSION = 2;

        @Override
        public JsonObject prepareInputsReranking(LLMClient.RerankingQuery query) throws IllegalArgumentException {
            JF.ObjectBuilder ob = JF.obj();
            ob.with("query", query.getQueryText());
            // The explicit Number cast pins overload resolution on JF.ObjectBuilder.with.
            ob.with("api_version", (Number) API_VERSION);
            ob.with("documents", query.documents.stream().map(doc -> doc.getText()).toList());
            return ob.get();
        }

        @Override
        public LLMClient.SingleRerankingResponse parseRerankingResponse(JsonElement response) throws IOException {
            if (response == null || !response.isJsonObject()) {
                throw new IOException("Invalid reranking response: expected a JSON object");
            }
            JsonElement results = response.getAsJsonObject().get("results");
            if (results == null || !results.isJsonArray()) {
                throw new IOException("Invalid reranking response: missing or non-array 'results' field");
            }
            JsonArray resultsArray = results.getAsJsonArray();
            LLMClient.SingleRerankingResponse ret = new LLMClient.SingleRerankingResponse();
            for (int i = 0; i < resultsArray.size(); ++i) {
                JsonElement item = resultsArray.get(i);
                if (item == null || !item.isJsonObject()) {
                    throw new IOException("Invalid reranking response: results[" + i + "] is not a JSON object");
                }
                JsonObject rankedDocument = item.getAsJsonObject();
                int index;
                float relevanceScore;
                try {
                    index = requirePrimitive(rankedDocument, "index", i).getAsInt();
                    relevanceScore = requirePrimitive(rankedDocument, "relevance_score", i).getAsFloat();
                } catch (NumberFormatException e) {
                    throw new IOException("Invalid reranking response: non-numeric value in results[" + i + "]", e);
                }
                ret.documents.add(new LLMClient.RerankedDocument(index, relevanceScore));
            }
            return ret;
        }

        /**
         * Returns member {@code name} of the {@code position}-th result entry.
         *
         * <p>The member must be a JSON primitive. Numeric strings are still accepted by the caller's
         * {@code getAsInt}/{@code getAsFloat}.
         *
         * @throws IOException if the member is absent, JSON null, an object or an array
         */
        private static JsonElement requirePrimitive(JsonObject result, String name, int position) throws IOException {
            JsonElement value = result.get(name);
            if (value == null || !value.isJsonPrimitive()) {
                throw new IOException("Invalid reranking response: missing or non-scalar '" + name
                        + "' in results[" + position + "]");
            }
            return value;
        }
    }
}

