/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.cohere;

import com.dataiku.common.rpc.ExternalJSONAPIClient;
import com.dataiku.dip.ProxySettings;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettingsValidator;
import com.dataiku.dip.llm.online.marshall.FinishReasonResponseAdapter;
import com.dataiku.dip.llm.utils.OnlineLLMUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.dataiku.dip.utils.JSON;
import com.google.gson.JsonObject;
import com.google.gson.annotations.SerializedName;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

/**
 * Low-level client for the Cohere v1 HTTP API.
 *
 * <p>Translates generic {@link LLMClient.ChatMessage} lists and {@link CoreCompletionSettings}
 * into Cohere-specific JSON payloads, posts them through an {@link ExternalJSONAPIClient}, and
 * maps the raw responses back to {@link LLMClient.SimpleCompletionResponse}.
 */
public class RawCohereClient {
    private static final String DEFAULT_ENDPOINT_BASE = "https://api.cohere.ai/v1";
    private final AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings;
    private final ExternalJSONAPIClient client;
    /**
     * Declares which core completion settings the chat endpoint supports; checked at the start of
     * every {@link #chatComplete} call.
     */
    private static final CoreCompletionSettingsValidator chatAPIValidator = new CoreCompletionSettingsValidator("Cohere (chat)")
            .allowMaxTokens()
            .allowTemperature()
            .allowTopK()
            .allowTopP()
            .allowFrequencyPenalty()
            .allowPresencePenalty()
            .allowStopSequences();
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.cohere.client");

    /**
     * Creates a client bound to the default Cohere v1 endpoint (https://api.cohere.ai/v1).
     *
     * @param apiKey Cohere API key, sent as a {@code Bearer} token in the Authorization header
     * @param networkSettings network settings passed to the retrying JSON client; its
     *     {@code queryTimeoutMS} is also used as the per-request timeout in {@link #chatComplete}
     * @param proxySettings proxy configuration forwarded to the underlying JSON client
     * @param forceContentLength when true, sets {@code forceContentLength} on the underlying client
     */
    public RawCohereClient(String apiKey, AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings, ProxySettings proxySettings, boolean forceContentLength) {
        this.networkSettings = networkSettings;
        this.client = OnlineLLMUtils.getExternalJSONClientWithRetryStrategy(DEFAULT_ENDPOINT_BASE, null, false, proxySettings, networkSettings);
        this.client.addHeader("Authorization", "Bearer " + apiKey);
        if (forceContentLength) {
            this.client.forceContentLength = true;
        }
    }

    /** Closes the underlying JSON API client. */
    public void close() {
        this.client.close();
    }

    /**
     * Builds the JSON body for a Cohere prompt-completion request.
     *
     * <p>Generic setting names are mapped to Cohere's wire names ({@code topK} becomes {@code k},
     * {@code topP} becomes {@code p}). Optional settings are only emitted when set.
     *
     * @param model model identifier; omitted from the payload when {@code null}
     * @param prompt prompt text, always emitted as {@code prompt}
     * @param ccs sampling settings; {@code max_tokens} falls back to 256 when unset
     * @return the request body
     */
    public static JsonObject buildCompleteQuery(String model, String prompt, CoreCompletionSettings ccs) {
        JF.ObjectBuilder ob = JF.obj().with("prompt", prompt);
        if (model != null) {
            ob.with("model", model);
        }
        // Always send max_tokens so the generation length is bounded on our side.
        ob.with("max_tokens", ccs.maxTokens != null ? (Number)ccs.maxTokens : (Number)256);
        if (ccs.temperature != null) {
            ob.with("temperature", (Number)ccs.temperature);
        }
        if (ccs.topK != null) {
            ob.with("k", (Number)ccs.topK);
        }
        if (ccs.topP != null) {
            ob.with("p", (Number)ccs.topP);
        }
        if (ccs.frequencyPenalty != null) {
            ob.with("frequency_penalty", (Number)ccs.frequencyPenalty);
        }
        if (ccs.presencePenalty != null) {
            ob.with("presence_penalty", (Number)ccs.presencePenalty);
        }
        if (ccs.stopSequences != null && !ccs.stopSequences.isEmpty()) {
            ob.with("stop_sequences", ccs.stopSequences);
        }
        return ob.get();
    }

    /**
     * Runs a non-streaming chat completion against Cohere's {@code chat} endpoint.
     *
     * @param messages conversation to send (see {@link #buildChatQuery} for how it is mapped)
     * @param ccs completion settings; validated against the settings the chat API supports
     * @return the completion text, adapted finish reason and, when reported, billed token counts
     * @throws IOException if the HTTP call fails or the response carries no completion text
     */
    public LLMClient.SimpleCompletionResponse chatComplete(List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) throws IOException {
        chatAPIValidator.validate(ccs);
        JsonObject rawChatQuery = RawCohereClient.buildChatQuery(messages, ccs);
        String endpoint = "chat";
        if (logger.isTraceEnabled()) {
            logger.trace((Object)("Raw Cohere LLM chat completion query: " + String.valueOf(rawChatQuery)));
        }
        RawChatResponse rawChatResponse = (RawChatResponse)this.client.postObjectToJSON(endpoint, this.networkSettings.queryTimeoutMS, RawChatResponse.class, (Object)rawChatQuery);
        if (logger.isTraceEnabled()) {
            logger.trace((Object)("Raw Cohere LLM chat completion response: " + JSON.json((Object)rawChatResponse)));
        }
        // Guard against an empty body as well as a body without text; both are unusable.
        if (rawChatResponse == null || rawChatResponse.text == null) {
            throw new IOException("Cohere LLM did not respond with valid chat completion");
        }
        LLMClient.SimpleCompletionResponse ret = new LLMClient.SimpleCompletionResponse();
        ret.text = rawChatResponse.text;
        ret.finishReason = FinishReasonResponseAdapter.adapt(rawChatResponse.finishReason);
        // Token usage is auxiliary metadata: a missing "meta" / "billed_units" section must not
        // turn an otherwise successful completion into a NullPointerException.
        RawMetaTokensResponse billedUnits = rawChatResponse.meta != null ? rawChatResponse.meta.billed_units : null;
        if (billedUnits != null) {
            if (billedUnits.input_tokens != null) {
                ret.promptTokens = billedUnits.input_tokens;
            }
            if (billedUnits.output_tokens != null) {
                ret.completionTokens = billedUnits.output_tokens;
            }
        }
        return ret;
    }

    /**
     * Builds the JSON body for Cohere's v1 {@code chat} endpoint.
     *
     * <p>Mapping rules:
     * <ul>
     *   <li>All {@code system} messages are concatenated, separated by a blank line, into
     *       {@code preamble}. When there are none, the preamble is the empty string.</li>
     *   <li>Every other message is appended to {@code chat_history}, in order, with its role
     *       converted by {@link #convertRoleToCohereRole}.</li>
     *   <li>The <em>last</em> {@code user} message is removed from {@code chat_history} and sent as
     *       {@code message} instead. Messages that follow it stay in the history.</li>
     * </ul>
     * {@code max_tokens} defaults to 2048 and {@code stream} is always {@code false}. The
     * {@code model} field is never set here.
     *
     * @param messages conversation to convert
     * @param ccs sampling settings; unset optional settings are left {@code null}
     * @return the request body as produced by {@link JSON#toJsonObject}
     */
    public static JsonObject buildChatQuery(List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) {
        RawChatQuery rawChatQuery = new RawChatQuery();
        List<String> systemPromptList = new ArrayList<String>();
        String userMessage = null;
        int userMessageIndex = -1;
        for (LLMClient.ChatMessage message : messages) {
            if ("system".equals(message.role)) {
                systemPromptList.add(message.getText());
                continue;
            }
            RawChatMessage chatMessage = new RawChatMessage();
            chatMessage.role = RawCohereClient.convertRoleToCohereRole(message.role);
            chatMessage.message = message.getText();
            rawChatQuery.chat_history.add(chatMessage);
            if ("user".equals(message.role)) {
                // Remember the latest user turn; it becomes the top-level "message" below.
                userMessage = message.getText();
                userMessageIndex = rawChatQuery.chat_history.size() - 1;
            }
        }
        rawChatQuery.preamble = String.join("\n\n", systemPromptList);
        rawChatQuery.message = userMessage;
        if (userMessageIndex != -1) {
            rawChatQuery.chat_history.remove(userMessageIndex);
        }
        rawChatQuery.max_tokens = ccs.maxTokens != null ? ccs.maxTokens : Integer.valueOf(2048);
        if (ccs.stopSequences != null && !ccs.stopSequences.isEmpty()) {
            rawChatQuery.stop_sequences = ccs.stopSequences;
        }
        rawChatQuery.stream = false;
        if (ccs.temperature != null) {
            rawChatQuery.temperature = ccs.temperature;
        }
        if (ccs.topP != null) {
            rawChatQuery.p = ccs.topP;
        }
        if (ccs.topK != null) {
            rawChatQuery.k = ccs.topK;
        }
        if (ccs.frequencyPenalty != null) {
            rawChatQuery.frequency_penalty = ccs.frequencyPenalty;
        }
        if (ccs.presencePenalty != null) {
            rawChatQuery.presence_penalty = ccs.presencePenalty;
        }
        return JSON.toJsonObject((Object)rawChatQuery);
    }

    /**
     * Maps a generic chat role to Cohere's role vocabulary: {@code assistant} becomes
     * {@code CHATBOT}; any other role is upper-cased.
     *
     * <p>Upper-casing uses {@link Locale#ROOT} so the result does not depend on the JVM default
     * locale (e.g. the Turkish dotted/dotless "i" rules).
     */
    private static String convertRoleToCohereRole(String role) {
        if ("assistant".equals(role)) {
            return "CHATBOT";
        }
        return role.toUpperCase(Locale.ROOT);
    }

    /** Subset of the chat endpoint's JSON response that this client reads. */
    private static class RawChatResponse {
        /** Generated completion text; {@code null} is treated as an invalid response. */
        String text;
        @SerializedName(value="finish_reason")
        String finishReason;
        /** Optional metadata carrying billed token counts. */
        RawMetaResponse meta;

        private RawChatResponse() {
        }
    }

    /** The {@code meta} section of a chat response. */
    private static class RawMetaResponse {
        RawMetaTokensResponse billed_units;

        private RawMetaResponse() {
        }
    }

    /** Billed token counts reported under {@code meta.billed_units}. */
    private static class RawMetaTokensResponse {
        Integer input_tokens;
        Integer output_tokens;

        private RawMetaTokensResponse() {
        }
    }

    /**
     * Request body for the chat endpoint. Field names are presumably used verbatim as JSON keys by
     * {@link JSON#toJsonObject} (hence the snake_case), so do not rename them.
     */
    private static class RawChatQuery {
        /** Never assigned by {@link #buildChatQuery}. */
        String model;
        List<RawChatMessage> chat_history = new ArrayList<RawChatMessage>();
        String message;
        String preamble;
        Integer max_tokens;
        List<String> stop_sequences;
        Boolean stream;
        Double temperature;
        Double p;
        Integer k;
        Double frequency_penalty;
        Double presence_penalty;

        private RawChatQuery() {
        }
    }

    /** One entry of {@code chat_history}: a Cohere role name and the message text. */
    private static class RawChatMessage {
        String role;
        String message;

        private RawChatMessage() {
        }
    }
}

