/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.anthropic;

import com.dataiku.common.rpc.ExternalJSONAPIClient;
import com.dataiku.dip.ProxySettings;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.anthropic.api.AnthropicChatQuery;
import com.dataiku.dip.llm.online.anthropic.api.AnthropicChatQueryAdapter;
import com.dataiku.dip.llm.online.anthropic.api.AnthropicChatResponse;
import com.dataiku.dip.llm.online.anthropic.api.AnthropicChatResponseAdapter;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettingsValidator;
import com.dataiku.dip.llm.online.marshall.FinishReasonResponseAdapter;
import com.dataiku.dip.llm.utils.OnlineLLMUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.com.google.common.base.MoreObjects;
import com.google.gson.JsonObject;
import com.google.gson.annotations.SerializedName;
import java.io.IOException;
import java.util.List;
import javax.annotation.Nullable;

/**
 * Low-level HTTP client for the Anthropic API.
 *
 * <p>Every request is sent through an {@link ExternalJSONAPIClient} rooted at
 * {@code https://api.anthropic.com/v1}. That client is configured with the {@code X-API-Key}
 * and {@code anthropic-version} headers. Call {@link #close()} to release the underlying client.
 */
public class RawAnthropicClient {
    private static final String DEFAULT_ENDPOINT_BASE = "https://api.anthropic.com/v1";
    private static final String ANTHROPIC_VERSION = "2023-06-01";
    /** Value written into {@code CoreCompletionSettings.maxTokens} when the caller left it unset. */
    private static final int DEFAULT_MAX_TOKENS = 256;
    private static final CoreCompletionSettingsValidator textCompletionValidator = new CoreCompletionSettingsValidator("Anthropic Claude (text)").allowMaxTokens().allowTemperature().allowTopK().allowTopP().allowStopSequences();
    // The chat validator is derived from the text one and additionally accepts tools.
    private static final CoreCompletionSettingsValidator chatCompletionValidator = new CoreCompletionSettingsValidator("Anthropic Claude (chat)", textCompletionValidator).allowTools();
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.anthropic.client");

    private final AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings;
    ExternalJSONAPIClient client;

    /**
     * Creates a client bound to the default Anthropic endpoint.
     *
     * @param apiKey          value sent in the {@code X-API-Key} header
     * @param networkSettings network settings used to build the HTTP client; its
     *                        {@code queryTimeoutMS} is also used by {@link #chatComplete}
     * @param proxySettings   proxy configuration forwarded to the HTTP client
     */
    public RawAnthropicClient(String apiKey, AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings, ProxySettings proxySettings) {
        this.networkSettings = networkSettings;
        this.client = OnlineLLMUtils.getExternalJSONClientWithRetryStrategy(DEFAULT_ENDPOINT_BASE, null, false, proxySettings, networkSettings);
        this.client.addHeader("X-API-Key", apiKey);
        this.client.addHeader("anthropic-version", ANTHROPIC_VERSION);
    }

    /** Closes the underlying HTTP client. */
    public void close() {
        this.client.close();
    }

    /**
     * Runs a text completion against the {@code /complete} endpoint.
     *
     * <p>Side effect: if {@code ccs.maxTokens} is null it is set to {@value #DEFAULT_MAX_TOKENS}
     * on the caller's object.
     *
     * @param model  model identifier; omitted from the request when null
     * @param prompt prompt text sent as-is
     * @param ccs    completion settings; validated against the text-completion validator
     * @return the completion text and its adapted finish reason
     * @throws IOException if the HTTP call fails, or if the response is empty or has no
     *                     {@code completion} field
     */
    public LLMClient.SimpleCompletionResponse complete(String model, String prompt, CoreCompletionSettings ccs) throws IOException {
        textCompletionValidator.validate(ccs);
        applyDefaultMaxTokens(ccs);
        JsonObject query = buildCompleteQuery(model, prompt, ccs);
        // NOTE(review): unlike chatComplete, this call passes no networkSettings.queryTimeoutMS and
        // uses a leading-slash path ("/complete" vs "messages") — confirm both are intentional.
        RawGenerationResponse rcr = (RawGenerationResponse) this.client.postObjectToJSON("/complete", RawGenerationResponse.class, (Object) query);
        // A null parsed body must be reported like a missing completion, not as an NPE.
        if (rcr == null || rcr.completion == null) {
            throw new IOException("Anthropic API did not respond with valid completion");
        }
        LLMClient.SimpleCompletionResponse ret = new LLMClient.SimpleCompletionResponse();
        ret.text = rcr.completion;
        ret.finishReason = FinishReasonResponseAdapter.adapt(rcr.stopReason);
        return ret;
    }

    /**
     * Builds the JSON body for the {@code /complete} endpoint.
     *
     * <p>{@code prompt} is always included. {@code model}, {@code max_tokens_to_sample},
     * {@code temperature}, {@code top_k} and {@code top_p} are included only when non-null.
     * {@code stop_sequences} is included only when non-null and non-empty.
     *
     * @param model  model identifier, or null to omit it
     * @param prompt prompt text
     * @param ccs    completion settings to map onto request fields
     * @return the request body
     */
    public static JsonObject buildCompleteQuery(String model, String prompt, CoreCompletionSettings ccs) {
        JF.ObjectBuilder ob = JF.obj().with("prompt", prompt);
        if (model != null) {
            ob.with("model", model);
        }
        // The (Number) casts pin the numeric overload of ObjectBuilder.with.
        if (ccs.maxTokens != null) {
            ob.with("max_tokens_to_sample", (Number) ccs.maxTokens);
        }
        if (ccs.temperature != null) {
            ob.with("temperature", (Number) ccs.temperature);
        }
        if (ccs.topK != null) {
            ob.with("top_k", (Number) ccs.topK);
        }
        if (ccs.topP != null) {
            ob.with("top_p", (Number) ccs.topP);
        }
        if (ccs.stopSequences != null && !ccs.stopSequences.isEmpty()) {
            ob.with("stop_sequences", ccs.stopSequences);
        }
        return ob.get();
    }

    /**
     * Runs a chat completion against the {@code messages} endpoint.
     *
     * <p>Side effect: if {@code ccs.maxTokens} is null it is set to {@value #DEFAULT_MAX_TOKENS}
     * on the caller's object. The query and response are logged at TRACE level.
     *
     * @param model    model identifier passed to the query adapter
     * @param messages conversation passed to the query adapter
     * @param ccs      completion settings; validated against the chat validator (allows tools)
     * @return the response as adapted by {@link AnthropicChatResponseAdapter}
     * @throws IOException if the HTTP call fails
     */
    public LLMClient.SimpleCompletionResponse chatComplete(String model, List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) throws IOException {
        chatCompletionValidator.validate(ccs);
        applyDefaultMaxTokens(ccs);
        AnthropicChatQuery query = AnthropicChatQueryAdapter.adapt(model, messages, ccs);
        logger.trace(() -> String.format("Anthropic raw chat completion query: %s", JSON.pretty((Object) query)));
        String endpoint = "messages";
        AnthropicChatResponse response = (AnthropicChatResponse) this.client.postObjectToJSON(endpoint, this.networkSettings.queryTimeoutMS, AnthropicChatResponse.class, (Object) query);
        logger.trace(() -> String.format("Anthropic raw chat completion response: %s", JSON.pretty((Object) response)));
        return AnthropicChatResponseAdapter.adapt(response);
    }

    /** Sets {@code ccs.maxTokens} to {@link #DEFAULT_MAX_TOKENS} when it is null; mutates {@code ccs}. */
    private static void applyDefaultMaxTokens(CoreCompletionSettings ccs) {
        if (ccs.maxTokens == null) {
            ccs.maxTokens = DEFAULT_MAX_TOKENS;
        }
    }

    /** JSON shape read back from the {@code /complete} endpoint. */
    private static class RawGenerationResponse {
        // Generated text; null is treated as an invalid response by complete().
        String completion;
        @Nullable
        @SerializedName(value="stop_reason")
        String stopReason;

        private RawGenerationResponse() {
        }
    }
}

