package com.customllm.llm;

import java.io.IOException;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutionException;

import com.dataiku.dss.shadelib.org.apache.http.client.methods.HttpPost;
import com.dataiku.dss.shadelib.com.nimbusds.oauth2.sdk.ParseException;

import com.dataiku.common.rpc.ExternalJSONAPIClient;
import com.dataiku.common.rpc.ExternalJSONAPIClient.EntityAndRequest;
import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.connections.AbstractLLMConnection.HTTPBasedLLMNetworkSettings;
import com.dataiku.dip.custom.PluginSettingsResolver.ResolvedSettings;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.custom.CustomLLMClient;
import com.dataiku.dip.llm.custom.RateLimitingJSONAPIClient;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMClient.ChatMessage;
import com.dataiku.dip.llm.online.LLMClient.CompletionQuery;
import com.dataiku.dip.llm.online.LLMClient.EmbeddingQuery;
import com.dataiku.dip.llm.online.LLMClient.EmbeddingSettings;
import com.dataiku.dip.llm.online.LLMClient.SimpleCompletionResponse;
import com.dataiku.dip.llm.online.LLMClient.SimpleEmbeddingResponse;
import com.dataiku.dip.llm.online.LLMClient.StreamedCompletionResponseChunk;
import com.dataiku.dip.llm.online.LLMClient.StreamedCompletionResponseConsumer;
import com.dataiku.dip.llm.online.LLMClient.StreamedCompletionResponseFooter;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.openai.OpenAIChatCompletionSettingsValidator;
import com.dataiku.dip.llm.online.openai.OpenAIMode;
import com.dataiku.dip.llm.online.openai.OpenAIPricing;
import com.dataiku.dip.llm.online.openai.api.chatcompletions.OpenAIChatChunkResponse;
import com.dataiku.dip.llm.online.openai.api.chatcompletions.OpenAIChatChunkResponseAdapter;
import com.dataiku.dip.llm.online.openai.api.chatcompletions.OpenAIChatQuery;
import com.dataiku.dip.llm.online.openai.api.chatcompletions.OpenAIChatQueryAdapter;
import com.dataiku.dip.llm.online.openai.api.chatcompletions.OpenAIChatResponse;
import com.dataiku.dip.llm.online.openai.api.chatcompletions.OpenAIChatResponseAdapter;
import com.dataiku.dip.recipes.nlp.common.LLMCompletionSettings;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.resourceusage.ComputeResourceUsage.InternalLLMUsageData;
import com.dataiku.dip.resourceusage.ComputeResourceUsage.LLMUsageData;
import com.dataiku.dip.resourceusage.ComputeResourceUsage.LLMUsageType;
import com.dataiku.dip.security.model.OAuth2Client;
import com.dataiku.dip.streaming.endpoints.httpsse.SSEDecoder;
import com.dataiku.dip.streaming.endpoints.httpsse.SSEDecoder.HTTPSSEEvent;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.dataiku.dip.utils.JF.ObjectBuilder;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.StringUtils;
import com.google.gson.*;


/**
 * Custom LLM client that posts OpenAI-style chat completion and embedding requests to a
 * configured endpoint URL, adding a subscription-key header, an OAuth2 bearer token and
 * optional correlation / client id headers to every request.
 *
 * <p>All settings are read in {@link #init(ResolvedSettings)}; the other methods assume that
 * {@code init} has been called first.
 */
public class azureOpenAIAPIMCustomLLMPlugin extends CustomLLMClient {
    public azureOpenAIAPIMCustomLLMPlugin() {
    }

    private String endpointUrl;
    private String model;
    private boolean shouldUseMaxCompletionTokens;
    private String auth_type;
    private ResolvedSettings rs;
    private ExternalJSONAPIClient client;
    private final InternalLLMUsageData usageData = new LLMUsageData();
    private final HTTPBasedLLMNetworkSettings networkSettings = new HTTPBasedLLMNetworkSettings();
    private int maxParallelism = 1;

    private String subscriptionKeyHeaderName;
    private boolean sendCorrelationId;
    private boolean sendClientId;
    private String correlationIdHeaderName;

    private OAuth2Client oauth2client;
    private String OAuthAccessToken;

    private OpenAIChatCompletionSettingsValidator chatCompletionValidator;

    private String tokenEndpoint;
    private String scope;
    private String clientId;
    private String clientSecret;

    private String subscription_key;

    /**
     * Value of the {@code maxTokensApiMode} setting. {@code MODERN} turns on the
     * {@code shouldUseMaxCompletionTokens} flag handed to the OpenAI query adapter;
     * {@code LEGACY} leaves it off.
     */
    private enum MaxTokensAPIMode {
        MODERN,
        LEGACY
    }

    /**
     * Parses the {@code maxTokensApiMode} configuration value.
     *
     * @param value raw configuration element; may be {@code null}, a JSON null or a blank string
     * @return the parsed mode, or {@link MaxTokensAPIMode#LEGACY} when the value is absent or blank
     * @throws IllegalArgumentException if the value is non-blank but names no known mode
     */
    private static MaxTokensAPIMode parseMaxTokensApiMode(JsonElement value) {
        // If empty, default to legacy, as this was the behaviour of the plugin before this parameter was added.
        // A JSON null must be handled here too: calling getAsString() on it throws.
        if (value == null || value.isJsonNull()) {
            return MaxTokensAPIMode.LEGACY;
        }
        String stringValue = value.getAsString();
        if (StringUtils.nullIfBlank(stringValue) == null) {
            return MaxTokensAPIMode.LEGACY;
        }
        // Tolerate surrounding whitespace in the stored setting
        return MaxTokensAPIMode.valueOf(stringValue.trim());
    }

    /**
     * Reads the connection and plugin settings, resolves the credentials for the configured
     * {@code auth_type} and builds the HTTP client used by all subsequent requests.
     *
     * <p>For {@code OAUTH2_CLIENT_CREDENTIALS} an access token is requested immediately from the
     * configured token endpoint; for {@code OAUTH2_PER_USER} the token is taken from the settings.
     *
     * @param settings resolved settings holding the per-model {@code config} and the {@code pluginConfig}
     * @throws RuntimeException if the OAuth2 client cannot be built or the access token cannot be acquired
     */
    @Override
    public void init(ResolvedSettings settings) {
        this.rs = settings;

        endpointUrl = rs.config.get("endpoint_url").getAsString();
        model = rs.config.get("model").getAsString();
        shouldUseMaxCompletionTokens = parseMaxTokensApiMode(rs.config.get("maxTokensApiMode")) == MaxTokensAPIMode.MODERN;

        maxParallelism = rs.config.get("maxParallelism").getAsNumber().intValue();
        networkSettings.queryTimeoutMS = rs.config.get("networkTimeout").getAsNumber().intValue();
        networkSettings.maxRetries = rs.config.get("maxRetries").getAsNumber().intValue();
        networkSettings.initialRetryDelayMS = rs.config.get("firstRetryDelay").getAsNumber().longValue();
        networkSettings.retryDelayScalingFactor = rs.config.get("retryDelayScale").getAsNumber().doubleValue();

        // Plugin parameters
        subscriptionKeyHeaderName = rs.pluginConfig.get("subscription_header_name").getAsString();
        sendCorrelationId = rs.pluginConfig.get("send_correlation_id").getAsBoolean();
        correlationIdHeaderName = rs.pluginConfig.get("correlation_id_header_name").getAsString();
        sendClientId = rs.pluginConfig.get("send_client_id").getAsBoolean();

        // Restrict header names to A-Za-z0-9_- by removing other characters
        subscriptionKeyHeaderName = subscriptionKeyHeaderName.replaceAll("[^\\w-]+", "");
        correlationIdHeaderName = correlationIdHeaderName.replaceAll("[^\\w-]+", "");

        // Authentication
        auth_type = rs.config.get("auth_type").getAsString();
        logger.info("Using auth_type: " + auth_type);

        String tmp_subscription_key = "";
        if ("OAUTH2_CLIENT_CREDENTIALS".equals(auth_type)) {
            // Look the credentials object up once instead of once per field
            JsonObject credentials = rs.config.get("entra_id_oauth_client_credentials").getAsJsonObject();
            tmp_subscription_key = credentials.get("subscription_key").getAsString();
            tokenEndpoint = credentials.get("token_endpoint").getAsString();
            scope = credentials.get("scope").getAsString();
            clientId = credentials.get("client_id").getAsString();
            clientSecret = credentials.get("client_secret").getAsString();

            // Build OAuth2Client
            try {
                oauth2client = new OAuth2Client.Builder()
                        .authorizationEndpoint(tokenEndpoint) // Authorization Endpoint not needed for client credentials grant (but must be populated)
                        .tokenEndpoint(tokenEndpoint)
                        .clientId(clientId)
                        .clientSecret(clientSecret)
                        .scope(scope)
                        .useGlobalProxy(true)
                        .useAccessTokenCache(false) // oauth2client Access Token cache not relevant for client credentials grant type
                        .build();
            } catch (DKUSecurityException e) {
                throw new RuntimeException("Failed to build OAuth2Client for Client Credentials", e);
            }
            // Acquire Access Token from oauth2client Credentials Grant - with useCache=true
            // NOTE(review): the token is fetched once here and never refreshed by this class;
            // confirm that client instances do not outlive the token's validity.
            try {
                OAuth2Client.AccessTokenResult accessTokenResponse = oauth2client.acquireAccessTokenResultWithClientCredentialsGrant(true);
                OAuthAccessToken = accessTokenResponse.getAccessToken();
                long secondsToExpiry = accessTokenResponse.getTimeLeft() / 1000L;
                logger.debug("OAuth2 access token acquired. Expires in: " + secondsToExpiry + " seconds");
            } catch (URISyntaxException | IOException | DKUSecurityException | ParseException e) {
                throw new RuntimeException("Failed to get OAuth2 access token with Client Credentials grant", e);
            }

        } else if ("OAUTH2_PER_USER".equals(auth_type)) {
            JsonObject perUser = rs.config.get("entra_id_oauthperuser").getAsJsonObject();
            tmp_subscription_key = perUser.get("subscription_key").getAsString();
            OAuthAccessToken = perUser.get("entraid_oauth").getAsString();
        } else {
            // Neither branch ran: the subscription key stays empty and there is no token to send
            logger.info("Unrecognized auth_type '" + auth_type + "': no OAuth2 access token will be sent and the subscription key is empty");
        }
        subscription_key = tmp_subscription_key;

        chatCompletionValidator = new OpenAIChatCompletionSettingsValidator();

        client = new RateLimitingJSONAPIClient(endpointUrl, null, true, ApplicationConfigurator.getProxySettings()) {

            /**
             * Creates the POST request and attaches the subscription key, the optional
             * correlation / client id headers and the bearer token.
             */
            @Override
            protected HttpPost newPost(String path) throws IOException, ExecutionException {
                HttpPost post = new HttpPost(path);

                setAdditionalHeadersInRequest(post);
                post.addHeader("Content-Type", "application/json");
                post.addHeader(subscriptionKeyHeaderName, subscription_key);
                if (sendCorrelationId) {
                    // A fresh id per request so each call can be traced individually
                    String correlation_id = "dataiku-" + UUID.randomUUID().toString();
                    post.addHeader(correlationIdHeaderName, correlation_id);
                    logger.debug("correlation id: " + correlation_id);
                }
                if (sendClientId) {
                    // clientId is only populated for the client credentials grant
                    if (clientId != null) {
                        post.addHeader("Client-ID", clientId);
                    } else {
                        logger.debug("send_client_id is enabled but no client id is available for auth_type: " + auth_type);
                    }
                }
                // Never send the literal string "Bearer null"
                if (OAuthAccessToken != null) {
                    post.addHeader("Authorization", ("Bearer " + OAuthAccessToken));
                }
                return post;
            }

        };
    }

    /** @return the {@code maxParallelism} value read from the configuration (1 before {@code init}) */
    @Override
    public int getMaxParallelism() {
        return maxParallelism;
    }

    /** @return retry settings derived from the network settings read in {@code init} */
    @Override
    public RateLimitingRetrySettings getRetrySettings() {
        return networkSettings.toRateLimitingRetrySettings();
    }

    /**
     * Runs the given completion queries one after another and records timing, token counts and
     * estimated cost of each in the accumulated usage data.
     *
     * @param completionQueries queries to execute, in order
     * @return one response per query, in the same order
     * @throws IOException if a request fails
     */
    @Override
    public List<SimpleCompletionResponse> completeBatch(List<LLMClient.CompletionQuery> completionQueries) throws IOException {
        List<SimpleCompletionResponse> ret = new ArrayList<>(completionQueries.size());
        for (CompletionQuery query : completionQueries) {
            long before = System.currentTimeMillis();

            SimpleCompletionResponse scr = chatComplete(model, query.messages, query.settings);
            scr.estimatedCost = getEstimatedCompletionCost(scr.promptTokens, scr.completionTokens);

            usageData.incrementTotalComputationTimeMS(System.currentTimeMillis() - before);
            usageData.incrementTotalPromptTokens(scr.promptTokens);
            usageData.incrementTotalCompletionTokens(scr.completionTokens);
            usageData.incrementEstimatedCostUSD(scr.estimatedCost);

            ret.add(scr);
        }

        return ret;
    }

    /**
     * Embeds the given texts one after another and records timing, token counts and estimated
     * cost of each in the accumulated usage data.
     *
     * @param queries  texts to embed, in order
     * @param settings embedding settings (not used by this implementation)
     * @return one response per query, in the same order
     * @throws IOException if a request fails or returns an unexpected payload
     */
    @Override
    public List<SimpleEmbeddingResponse> embedBatch(List<EmbeddingQuery> queries, EmbeddingSettings settings) throws IOException {
        List<SimpleEmbeddingResponse> ret = new ArrayList<>(queries.size());
        for (EmbeddingQuery query : queries) {
            long before = System.currentTimeMillis();
            SimpleEmbeddingResponse ser = embed(model, query.text);

            // AzureOpenAI pricing uses same as OpenAI pricing for embeddings for now.
            // A missing price counts as zero cost, as in getEstimatedCompletionCost.
            Double costPer1KTokens = OpenAIPricing.getOpenAIEmbeddingCostPer1KTokens(model);
            ser.estimatedCost = (costPer1KTokens == null) ? 0. : (costPer1KTokens * ser.promptTokens) / 1000;

            usageData.incrementTotalComputationTimeMS(System.currentTimeMillis() - before);
            usageData.incrementTotalPromptTokens(ser.promptTokens);
            usageData.incrementEstimatedCostUSD(ser.estimatedCost);

            ret.add(ser);
        }
        return ret;
    }

    /**
     * Builds a resource usage record from the usage accumulated by this client so far.
     *
     * @param usageType usage type to report
     * @param llmRef    reference whose connection, type and id are copied into the record
     * @return a new usage record
     */
    @Override
    public ComputeResourceUsage getTotalCRU(LLMUsageType usageType, LLMStructuredRef llmRef) {
        ComputeResourceUsage cru = new ComputeResourceUsage();
        cru.setupLLMUsage(usageType, llmRef.connection, llmRef.type.toString(), llmRef.id);
        cru.llmUsage.setFromInternal(this.usageData);
        return cru;
    }

    /**
     * Validates the settings, posts a single chat completion request to the endpoint and adapts
     * the response.
     *
     * @param model    model name placed in the query
     * @param messages conversation to send
     * @param cs       completion settings
     * @return the adapted completion response
     * @throws IOException               if the request fails
     * @throws LLMClient.RefusalException if the response carries a refusal; usage figures are
     *                                    attached when the response reports them
     */
    private SimpleCompletionResponse chatComplete(String model, List<ChatMessage> messages, LLMCompletionSettings cs) throws IOException {
        CoreCompletionSettings ccs = getCoreCompletionSettings(cs);
        chatCompletionValidator.validate(OpenAIMode.AZURE_OPENAI, ccs);
        OpenAIChatQuery query = OpenAIChatQueryAdapter.adapt(OpenAIMode.AZURE_OPENAI, model, messages, ccs, shouldUseMaxCompletionTokens);

        logger.info("posting chat completion to: " + endpointUrl);
        JsonElement jsonResponse = client.postObjectToJSON(endpointUrl, networkSettings.queryTimeoutMS, JsonElement.class, query);
        // The payload holds the full model output, so keep it out of INFO-level logs
        logger.debug("Json Response: " + jsonResponse);
        OpenAIChatResponse response = JSON.parse(jsonResponse, OpenAIChatResponse.class);
        String refusal = OpenAIChatResponseAdapter.getRefusal(response);
        if (refusal != null) {
            LLMClient.RefusalException refusalException = new LLMClient.RefusalException(refusal);
            if (response.usage != null) {
                refusalException.completionTokens = response.usage.completionTokens;
                refusalException.promptTokens = response.usage.promptTokens;
                refusalException.totalTokens = response.usage.totalTokens;
                refusalException.estimatedCost = getEstimatedCompletionCost(refusalException.promptTokens, refusalException.completionTokens);
            }
            throw refusalException;
        }

        return OpenAIChatResponseAdapter.adapt(response);
    }

    /**
     * Posts a single embedding request to the endpoint.
     *
     * @param model model name placed in the request body
     * @param text  text to embed
     * @return the embedding and the token count reported by the service (0 when no usage is reported)
     * @throws IOException if the request fails or the response does not contain exactly one embedding
     */
    private SimpleEmbeddingResponse embed(String model, String text) throws IOException {
        ObjectBuilder ob = JF.obj().with("input", text).with("model", model);

        logger.info("posting embedding to: " + endpointUrl);
        OpenAIEmbeddingResponse rcr = client.postObjectToJSON(endpointUrl, networkSettings.queryTimeoutMS,
                OpenAIEmbeddingResponse.class, ob.get());

        // Treat a missing body or missing data array like any other malformed response
        if (rcr == null || rcr.data == null || rcr.data.size() != 1) {
            throw new IOException("LLM did not respond with valid embeddings");
        }

        SimpleEmbeddingResponse ret = new SimpleEmbeddingResponse();
        ret.embedding = rcr.data.get(0).embedding;
        // The usage block may be absent; report zero tokens rather than failing
        ret.promptTokens = (rcr.usage != null) ? rcr.usage.total_tokens : 0;
        return ret;
    }

    /** @return {@code true}: this client implements {@link #streamComplete} */
    @Override
    public boolean supportsStream() {
        return true;
    }

    /**
     * Streams a chat completion for the given query to the consumer.
     *
     * @param query    completion query whose messages and settings are sent
     * @param consumer receiver of the stream events
     * @throws Exception if the request, decoding or the consumer fails, or the model refuses
     */
    @Override
    public void streamComplete(LLMClient.CompletionQuery query, LLMClient.StreamedCompletionResponseConsumer consumer) throws Exception {
        streamChatComplete(consumer, model, query.messages, query.settings);
    }

    /**
     * Posts a streaming chat completion request and forwards decoded server-sent events to the
     * consumer until the stream ends or a {@code [DONE]} marker is received.
     *
     * <p>Once any refusal text has been seen, content chunks are no longer forwarded and a
     * {@link LLMClient.RefusalException} is thrown after the stream ends instead of calling
     * {@code onStreamComplete}.
     *
     * @param consumer receiver of the stream events
     * @param model    model name placed in the query
     * @param messages conversation to send
     * @param cs       completion settings
     * @throws Exception if the request, decoding or the consumer fails, or the model refuses
     */
    private void streamChatComplete(StreamedCompletionResponseConsumer consumer, String model, List<ChatMessage> messages, LLMCompletionSettings cs) throws Exception {
        CoreCompletionSettings ccs = getCoreCompletionSettings(cs);
        chatCompletionValidator.validate(OpenAIMode.AZURE_OPENAI, ccs);
        OpenAIChatQuery query = OpenAIChatQueryAdapter.adaptForStreaming(OpenAIMode.AZURE_OPENAI, model, messages, ccs, shouldUseMaxCompletionTokens);

        logger.info("posting stream chat complete to: " + endpointUrl);
        // NOTE(review): the response entity / request is not explicitly released in this method;
        // confirm whether the client or the caller takes care of that.
        EntityAndRequest ear = client.postJSONToStreamAndRequest(endpointUrl, networkSettings.queryTimeoutMS, query);
        SSEDecoder decoder = new SSEDecoder(ear.entity.getContent());
        consumer.onStreamStarted();

        OpenAIChatResponse.Usage usage = null;
        LLMClient.FinishReason finishReason = null;
        StringBuilder refusalBuilder = null;
        while (true) {
            HTTPSSEEvent event = decoder.next();
            if (logger.isTraceEnabled()) {
                logger.trace("Received raw event from LLM: " + JSON.json(event));
            }

            if (event == null || event.data == null) {
                logger.info("End of LLM stream");
                break;
            }

            if (event.data.equals("[DONE]")) {
                logger.info("Received explicit end marker from LLM stream");
                break;
            }

            OpenAIChatChunkResponse response = JSON.parse(event.data, OpenAIChatChunkResponse.class);
            if (response == null) {
                continue; // nothing usable in this event
            }

            // Keep the most recent usage block; it is reported in the footer or the refusal
            if (response.usage != null) {
                usage = response.usage;
            }

            // Some chunks carry no choices (or omit the field entirely)
            if (response.choices == null || response.choices.isEmpty()) {
                continue; // nothing to do here
            }

            LLMClient.FinishReason reason = OpenAIChatChunkResponseAdapter.extractFinishReason(response);
            if (reason != null) {
                finishReason = reason;
            }

            String refusalChunk = OpenAIChatChunkResponseAdapter.getRefusal(response);
            if (refusalChunk != null) {
                if (refusalBuilder == null) {
                    refusalBuilder = new StringBuilder();
                }
                refusalBuilder.append(refusalChunk);
            }

            // Stop forwarding content as soon as a refusal has started
            if (refusalBuilder == null) {
                StreamedCompletionResponseChunk chunk = OpenAIChatChunkResponseAdapter.adapt(response);
                if (!chunk.isEmpty()) {
                    consumer.onStreamChunk(chunk);
                }
            }
        }

        if (refusalBuilder != null) {
            LLMClient.RefusalException refusalException = new LLMClient.RefusalException(refusalBuilder.toString());
            if (usage != null) {
                refusalException.completionTokens = usage.completionTokens;
                refusalException.promptTokens = usage.promptTokens;
                refusalException.totalTokens = usage.totalTokens;
                refusalException.estimatedCost = getEstimatedCompletionCost(refusalException.promptTokens, refusalException.completionTokens);
            }
            throw refusalException;
        }

        StreamedCompletionResponseFooter footer = new StreamedCompletionResponseFooter();
        if (usage != null) {
            footer.completionTokens = usage.completionTokens;
            footer.promptTokens = usage.promptTokens;
            footer.totalTokens = usage.totalTokens;
            footer.estimatedCost = getEstimatedCompletionCost(footer.promptTokens, footer.completionTokens);
        }

        if (finishReason != null) {
            footer.finishReason = finishReason;
        }

        consumer.onStreamComplete(footer);
    }

    /**
     * Estimates the cost of a completion from the per-1K-token prices looked up for the
     * configured model. A side whose price or token count is unavailable contributes nothing.
     *
     * @param promptTokens     number of prompt tokens, or {@code null} if unknown
     * @param completionTokens number of completion tokens, or {@code null} if unknown
     * @return the summed estimated cost, 0 when nothing can be priced
     */
    private double getEstimatedCompletionCost(Integer promptTokens, Integer completionTokens) {
        Double promptCost = OpenAIPricing.getAzureOpenAIPromptCostPer1KTokens(model);
        Double completionCost = OpenAIPricing.getAzureOpenAICompletionCostPer1KTokens(model);
        double totalCost = 0.;

        if ((promptCost != null) && (promptTokens != null)) {
            totalCost += promptCost * promptTokens / 1000.;
        }
        if ((completionCost != null) && (completionTokens != null)) {
            totalCost += completionCost * completionTokens / 1000.;
        }
        return totalCost;
    }

    /**
     * Copies the recipe-level completion settings into the core settings object expected by the
     * OpenAI validator and query adapter.
     *
     * @param cs recipe-level completion settings
     * @return a new core settings object; the response format is only set when {@code cs} has one
     */
    private CoreCompletionSettings getCoreCompletionSettings(LLMCompletionSettings cs) {
        CoreCompletionSettings ccs = new CoreCompletionSettings();
        ccs.maxTokens = cs.maxOutputTokens;
        ccs.temperature = cs.temperature;
        ccs.topP = cs.topP;
        ccs.topK = cs.topK;
        ccs.frequencyPenalty = cs.frequencyPenalty;
        ccs.presencePenalty = cs.presencePenalty;
        ccs.logProbs = cs.logProbs;
        ccs.topLogProbs = cs.topLogProbs;
        ccs.logitBias = cs.logitBias;
        ccs.stopSequences = cs.stopSequences;
        ccs.toolChoice = cs.toolChoice;
        ccs.tools = cs.tools;
        if (cs.responseFormat != null) {
            ccs.responseFormat = cs.responseFormat.toFullResponseFormat();
        }
        return ccs;
    }

    /** Embedding response body as deserialized from the endpoint's JSON. */
    private static class OpenAIEmbeddingResponse {
        List<OpenAIEmbeddingResult> data = new ArrayList<>();
        RawUsageResponse usage;
    }

    /** One embedding vector inside an {@link OpenAIEmbeddingResponse}. */
    private static class OpenAIEmbeddingResult {
        double[] embedding;
    }

    /** Token usage figures inside an {@link OpenAIEmbeddingResponse}. */
    private static class RawUsageResponse {
        int total_tokens;
        int prompt_tokens;
        int completion_tokens;
    }

    private static final DKULogger logger = DKULogger.getLogger("dku.llm.customplugin");
}
