/*
 * Decompiled with CFR 0.152.
 */
package com.customllm.llm;

import com.dataiku.common.rpc.ExternalJSONAPIClient;
import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.custom.PluginSettingsResolver;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.custom.CustomLLMClient;
import com.dataiku.dip.llm.custom.RateLimitingJSONAPIClient;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.openai.OpenAIChatCompletionSettingsValidator;
import com.dataiku.dip.llm.online.openai.OpenAIMode;
import com.dataiku.dip.llm.online.openai.OpenAIPricing;
import com.dataiku.dip.llm.online.openai.api.chatcompletions.OpenAIChatChunkResponse;
import com.dataiku.dip.llm.online.openai.api.chatcompletions.OpenAIChatChunkResponseAdapter;
import com.dataiku.dip.llm.online.openai.api.chatcompletions.OpenAIChatQuery;
import com.dataiku.dip.llm.online.openai.api.chatcompletions.OpenAIChatQueryAdapter;
import com.dataiku.dip.llm.online.openai.api.chatcompletions.OpenAIChatResponse;
import com.dataiku.dip.llm.online.openai.api.chatcompletions.OpenAIChatResponseAdapter;
import com.dataiku.dip.recipes.nlp.common.LLMCompletionSettings;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.security.model.OAuth2Client;
import com.dataiku.dip.streaming.endpoints.httpsse.SSEDecoder;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.StringUtils;
import com.dataiku.dss.shadelib.com.nimbusds.oauth2.sdk.ParseException;
import com.dataiku.dss.shadelib.org.apache.http.client.methods.HttpPost;
import com.dataiku.dss.shadelib.org.apache.http.client.methods.HttpRequestBase;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutionException;

public class azureOpenAIAPIMCustomLLMPlugin
extends CustomLLMClient {
    private String endpointUrl;
    private String model;
    private boolean shouldUseMaxCompletionTokens;
    private String auth_type;
    private PluginSettingsResolver.ResolvedSettings rs;
    private ExternalJSONAPIClient client;
    private ComputeResourceUsage.InternalLLMUsageData usageData = new ComputeResourceUsage.LLMUsageData();
    private AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings = new AbstractLLMConnection.HTTPBasedLLMNetworkSettings();
    private int maxParallelism = 1;
    private String subscriptionKeyHeaderName;
    private boolean sendCorrelationId;
    private String correlationIdHeaderName;
    private OAuth2Client oauth2client;
    private String OAuthAccessToken;
    private OpenAIChatCompletionSettingsValidator chatCompletionValidator;
    private String tokenEndpoint;
    private String scope;
    private String clientId;
    private String clientSecret;
    private String subscription_key;
    private static DKULogger logger = DKULogger.getLogger((String)"dku.llm.customplugin");

    private MaxTokensAPIMode parseMaxTokensApiMode(JsonElement value) {
        if (value == null) {
            return MaxTokensAPIMode.LEGACY;
        }
        String stringValue = value.getAsString();
        if (StringUtils.nullIfBlank((String)stringValue) == null) {
            return MaxTokensAPIMode.LEGACY;
        }
        return MaxTokensAPIMode.valueOf(stringValue);
    }

    public void init(PluginSettingsResolver.ResolvedSettings settings) {
        this.rs = settings;
        this.endpointUrl = this.rs.config.get("endpoint_url").getAsString();
        this.model = this.rs.config.get("model").getAsString();
        this.shouldUseMaxCompletionTokens = this.parseMaxTokensApiMode(this.rs.config.get("maxTokensApiMode")) == MaxTokensAPIMode.MODERN;
        this.maxParallelism = this.rs.config.get("maxParallelism").getAsNumber().intValue();
        this.networkSettings.queryTimeoutMS = this.rs.config.get("networkTimeout").getAsNumber().intValue();
        this.networkSettings.maxRetries = this.rs.config.get("maxRetries").getAsNumber().intValue();
        this.networkSettings.initialRetryDelayMS = this.rs.config.get("firstRetryDelay").getAsNumber().longValue();
        this.networkSettings.retryDelayScalingFactor = this.rs.config.get("retryDelayScale").getAsNumber().doubleValue();
        this.subscriptionKeyHeaderName = this.rs.pluginConfig.get("subscription_header_name").getAsString();
        this.sendCorrelationId = this.rs.pluginConfig.get("send_correlation_id").getAsBoolean();
        this.correlationIdHeaderName = this.rs.pluginConfig.get("correlation_id_header_name").getAsString();
        this.subscriptionKeyHeaderName = this.subscriptionKeyHeaderName.replaceAll("[^\\w-]+", "");
        this.correlationIdHeaderName = this.correlationIdHeaderName.replaceAll("[^\\w-]+", "");
        this.auth_type = this.rs.config.get("auth_type").getAsString();
        logger.info((Object)("Using auth_type: " + this.auth_type));
        String tmp_subscription_key = "";
        if (this.auth_type.equals("OAUTH2_CLIENT_CREDENTIALS")) {
            tmp_subscription_key = this.rs.config.get("entra_id_oauth_client_credentials").getAsJsonObject().get("subscription_key").getAsString();
            this.tokenEndpoint = this.rs.config.get("entra_id_oauth_client_credentials").getAsJsonObject().get("token_endpoint").getAsString();
            this.scope = this.rs.config.get("entra_id_oauth_client_credentials").getAsJsonObject().get("scope").getAsString();
            this.clientId = this.rs.config.get("entra_id_oauth_client_credentials").getAsJsonObject().get("client_id").getAsString();
            this.clientSecret = this.rs.config.get("entra_id_oauth_client_credentials").getAsJsonObject().get("client_secret").getAsString();
            try {
                this.oauth2client = new OAuth2Client.Builder().authorizationEndpoint(this.tokenEndpoint).tokenEndpoint(this.tokenEndpoint).clientId(this.clientId).clientSecret(this.clientSecret).scope(this.scope).useGlobalProxy(true).useAccessTokenCache(false).build();
            }
            catch (DKUSecurityException e) {
                throw new RuntimeException("Failed to build OAuth2Client for Client Credentials", e);
            }
            try {
                OAuth2Client.AccessTokenResult accessTokenResponse = this.oauth2client.acquireAccessTokenResultWithClientCredentialsGrant(true);
                this.OAuthAccessToken = accessTokenResponse.getAccessToken();
                Long secondsToExpiry = accessTokenResponse.getTimeLeft() / 1000L;
                logger.debug((Object)("OAuth2 access token acquired. Expires in: " + secondsToExpiry + " seconds"));
            }
            catch (DKUSecurityException | ParseException | IOException | URISyntaxException e) {
                throw new RuntimeException("Failed to get OAuth2 access token with Client Credentials grant", e);
            }
        }
        if (this.auth_type.equals("OAUTH2_PER_USER")) {
            tmp_subscription_key = this.rs.config.get("entra_id_oauthperuser").getAsJsonObject().get("subscription_key").getAsString();
            this.OAuthAccessToken = this.rs.config.get("entra_id_oauthperuser").getAsJsonObject().get("entraid_oauth").getAsString();
        }
        this.subscription_key = tmp_subscription_key;
        this.chatCompletionValidator = new OpenAIChatCompletionSettingsValidator();
        this.client = new RateLimitingJSONAPIClient(this.endpointUrl, null, true, ApplicationConfigurator.getProxySettings()){

            protected HttpPost newPost(String path) throws IOException, ExecutionException {
                HttpPost post = new HttpPost(path);
                this.setAdditionalHeadersInRequest((HttpRequestBase)post);
                post.addHeader("Content-Type", "application/json");
                post.addHeader(azureOpenAIAPIMCustomLLMPlugin.this.subscriptionKeyHeaderName, azureOpenAIAPIMCustomLLMPlugin.this.subscription_key);
                if (azureOpenAIAPIMCustomLLMPlugin.this.sendCorrelationId) {
                    String correlation_id = "dataiku-" + UUID.randomUUID().toString();
                    post.addHeader(azureOpenAIAPIMCustomLLMPlugin.this.correlationIdHeaderName, correlation_id);
                    logger.debug((Object)("correlation id: " + correlation_id));
                }
                post.addHeader("Authorization", "Bearer " + azureOpenAIAPIMCustomLLMPlugin.this.OAuthAccessToken);
                return post;
            }
        };
    }

    public int getMaxParallelism() {
        return this.maxParallelism;
    }

    public CustomLLMClient.RateLimitingRetrySettings getRetrySettings() {
        return this.networkSettings.toRateLimitingRetrySettings();
    }

    public List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.CompletionQuery> completionQueries) throws IOException {
        ArrayList<LLMClient.SimpleCompletionResponse> ret = new ArrayList<LLMClient.SimpleCompletionResponse>();
        for (LLMClient.CompletionQuery query : completionQueries) {
            long before = System.currentTimeMillis();
            LLMClient.SimpleCompletionResponse scr = null;
            scr = this.chatComplete(this.model, query.messages, query.settings);
            scr.estimatedCost = this.getEstimatedCompletionCost(scr.promptTokens, scr.completionTokens);
            this.usageData.incrementTotalComputationTimeMS(Long.valueOf(System.currentTimeMillis() - before));
            this.usageData.incrementTotalPromptTokens(scr.promptTokens);
            this.usageData.incrementTotalCompletionTokens(scr.completionTokens);
            this.usageData.incrementEstimatedCostUSD(scr.estimatedCost);
            ret.add(scr);
        }
        return ret;
    }

    public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws IOException {
        ArrayList<LLMClient.SimpleEmbeddingResponse> ret = new ArrayList<LLMClient.SimpleEmbeddingResponse>();
        for (LLMClient.EmbeddingQuery query : queries) {
            long before = System.currentTimeMillis();
            LLMClient.SimpleEmbeddingResponse ser = this.embed(this.model, query.text);
            ser.estimatedCost = OpenAIPricing.getOpenAIEmbeddingCostPer1KTokens((String)this.model) * (double)ser.promptTokens.intValue() / 1000.0;
            this.usageData.incrementTotalComputationTimeMS(Long.valueOf(System.currentTimeMillis() - before));
            this.usageData.incrementTotalPromptTokens(ser.promptTokens);
            this.usageData.incrementEstimatedCostUSD(ser.estimatedCost);
            ret.add(ser);
        }
        return ret;
    }

    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, LLMStructuredRef llmRef) {
        ComputeResourceUsage cru = new ComputeResourceUsage();
        cru.setupLLMUsage(usageType, llmRef.connection, llmRef.type.toString(), llmRef.id);
        cru.llmUsage.setFromInternal(this.usageData);
        return cru;
    }

    private LLMClient.SimpleCompletionResponse chatComplete(String model, List<LLMClient.ChatMessage> messages, LLMCompletionSettings cs) throws IOException {
        CoreCompletionSettings ccs = this.getCoreCompletionSettings(cs);
        this.chatCompletionValidator.validate(OpenAIMode.AZURE_OPENAI, ccs);
        OpenAIChatQuery query = OpenAIChatQueryAdapter.adapt((OpenAIMode)OpenAIMode.AZURE_OPENAI, (String)model, messages, (CoreCompletionSettings)ccs, (boolean)this.shouldUseMaxCompletionTokens);
        logger.info((Object)("posting chat completion to: " + this.endpointUrl));
        OpenAIChatResponse response = (OpenAIChatResponse)this.client.postObjectToJSON(this.endpointUrl, this.networkSettings.queryTimeoutMS, OpenAIChatResponse.class, (Object)query);
        String refusal = OpenAIChatResponseAdapter.getRefusal((OpenAIChatResponse)response);
        if (refusal != null) {
            LLMClient.RefusalException refusalException = new LLMClient.RefusalException(refusal);
            if (response.usage != null) {
                refusalException.completionTokens = response.usage.completionTokens;
                refusalException.promptTokens = response.usage.promptTokens;
                refusalException.totalTokens = response.usage.totalTokens;
                refusalException.estimatedCost = this.getEstimatedCompletionCost(refusalException.promptTokens, refusalException.completionTokens);
            }
            throw refusalException;
        }
        return OpenAIChatResponseAdapter.adapt((OpenAIChatResponse)response);
    }

    private LLMClient.SimpleEmbeddingResponse embed(String model, String text) throws IOException {
        JF.ObjectBuilder ob = JF.obj().with("input", text).with("model", model);
        logger.info((Object)("posting embedding to: " + this.endpointUrl));
        OpenAIEmbeddingResponse rcr = (OpenAIEmbeddingResponse)this.client.postObjectToJSON(this.endpointUrl, this.networkSettings.queryTimeoutMS, OpenAIEmbeddingResponse.class, (Object)ob.get());
        if (rcr.data.size() != 1) {
            throw new IOException("LLM did not respond with valid embeddings");
        }
        LLMClient.SimpleEmbeddingResponse ret = new LLMClient.SimpleEmbeddingResponse();
        ret.embedding = rcr.data.get((int)0).embedding;
        ret.promptTokens = rcr.usage.total_tokens;
        return ret;
    }

    public boolean supportsStream() {
        return true;
    }

    public void streamComplete(LLMClient.CompletionQuery query, LLMClient.StreamedCompletionResponseConsumer consumer) throws Exception {
        this.streamChatComplete(consumer, this.model, query.messages, query.settings);
    }

    private void streamChatComplete(LLMClient.StreamedCompletionResponseConsumer consumer, String model, List<LLMClient.ChatMessage> messages, LLMCompletionSettings cs) throws Exception {
        CoreCompletionSettings ccs = this.getCoreCompletionSettings(cs);
        this.chatCompletionValidator.validate(OpenAIMode.AZURE_OPENAI, ccs);
        OpenAIChatQuery query = OpenAIChatQueryAdapter.adaptForStreaming((OpenAIMode)OpenAIMode.AZURE_OPENAI, (String)model, messages, (CoreCompletionSettings)ccs, (boolean)this.shouldUseMaxCompletionTokens);
        logger.info((Object)("posting stream chat complete to: " + this.endpointUrl));
        ExternalJSONAPIClient.EntityAndRequest ear = this.client.postJSONToStreamAndRequest(this.endpointUrl, this.networkSettings.queryTimeoutMS, (Object)query);
        SSEDecoder decoder = new SSEDecoder(ear.entity.getContent());
        consumer.onStreamStarted();
        OpenAIChatResponse.Usage usage = null;
        LLMClient.FinishReason finishReason = null;
        StringBuilder refusalBuilder = null;
        while (true) {
            LLMClient.StreamedCompletionResponseChunk chunk;
            String refusalChunk;
            SSEDecoder.HTTPSSEEvent event = decoder.next();
            if (logger.isTraceEnabled()) {
                logger.trace((Object)("Received raw event from LLM: " + JSON.json((Object)event)));
            }
            if (event == null || event.data == null) {
                logger.info((Object)"End of LLM stream");
                break;
            }
            if (event.data.equals("[DONE]")) {
                logger.info((Object)"Received explicit end marker from LLM stream");
                break;
            }
            OpenAIChatChunkResponse response = (OpenAIChatChunkResponse)JSON.parse((String)event.data, OpenAIChatChunkResponse.class);
            if (response.usage != null) {
                usage = response.usage;
            }
            if (response.choices.isEmpty()) continue;
            LLMClient.FinishReason reason = OpenAIChatChunkResponseAdapter.extractFinishReason((OpenAIChatChunkResponse)response);
            if (reason != null) {
                finishReason = reason;
            }
            if ((refusalChunk = OpenAIChatChunkResponseAdapter.getRefusal((OpenAIChatChunkResponse)response)) != null) {
                if (refusalBuilder == null) {
                    refusalBuilder = new StringBuilder();
                }
                refusalBuilder.append(refusalChunk);
            }
            if (refusalBuilder != null || (chunk = OpenAIChatChunkResponseAdapter.adapt((OpenAIChatChunkResponse)response)).isEmpty()) continue;
            consumer.onStreamChunk(chunk);
        }
        if (refusalBuilder != null) {
            LLMClient.RefusalException refusalException = new LLMClient.RefusalException(refusalBuilder.toString());
            if (usage != null) {
                refusalException.completionTokens = usage.completionTokens;
                refusalException.promptTokens = usage.promptTokens;
                refusalException.totalTokens = usage.totalTokens;
                refusalException.estimatedCost = this.getEstimatedCompletionCost(refusalException.promptTokens, refusalException.completionTokens);
            }
            throw refusalException;
        }
        LLMClient.StreamedCompletionResponseFooter footer = new LLMClient.StreamedCompletionResponseFooter();
        if (usage != null) {
            footer.completionTokens = usage.completionTokens;
            footer.promptTokens = usage.promptTokens;
            footer.totalTokens = usage.totalTokens;
            footer.estimatedCost = this.getEstimatedCompletionCost(footer.promptTokens, footer.completionTokens);
        }
        if (finishReason != null) {
            footer.finishReason = finishReason;
        }
        consumer.onStreamComplete(footer);
    }

    private double getEstimatedCompletionCost(Integer promptTokens, Integer completionTokens) {
        Double promptCost = OpenAIPricing.getAzureOpenAIPromptCostPer1KTokens((String)this.model);
        Double completionCost = OpenAIPricing.getAzureOpenAICompletionCostPer1KTokens((String)this.model);
        double totalCost = 0.0;
        if (promptCost != null && promptTokens != null) {
            totalCost += promptCost * (double)promptTokens.intValue() / 1000.0;
        }
        if (completionCost != null && completionTokens != null) {
            totalCost += completionCost * (double)completionTokens.intValue() / 1000.0;
        }
        return totalCost;
    }

    private CoreCompletionSettings getCoreCompletionSettings(LLMCompletionSettings cs) {
        CoreCompletionSettings ccs = new CoreCompletionSettings();
        ccs.maxTokens = cs.maxOutputTokens;
        ccs.temperature = cs.temperature;
        ccs.topP = cs.topP;
        ccs.topK = cs.topK;
        ccs.frequencyPenalty = cs.frequencyPenalty;
        ccs.presencePenalty = cs.presencePenalty;
        ccs.logProbs = cs.logProbs;
        ccs.topLogProbs = cs.topLogProbs;
        ccs.logitBias = cs.logitBias;
        ccs.stopSequences = cs.stopSequences;
        ccs.toolChoice = cs.toolChoice;
        ccs.tools = cs.tools;
        if (cs.responseFormat != null) {
            ccs.responseFormat = cs.responseFormat.toFullResponseFormat();
        }
        return ccs;
    }

    private static enum MaxTokensAPIMode {
        MODERN,
        LEGACY;

    }

    private static class OpenAIEmbeddingResponse {
        List<OpenAIEmbeddingResult> data = new ArrayList<OpenAIEmbeddingResult>();
        RawUsageResponse usage;

        private OpenAIEmbeddingResponse() {
        }
    }

    private static class OpenAIEmbeddingResult {
        double[] embedding;

        private OpenAIEmbeddingResult() {
        }
    }

    private static class RawUsageResponse {
        int total_tokens;
        int prompt_tokens;
        int completion_tokens;

        private RawUsageResponse() {
        }
    }
}

