package com.customllm.llm;

import java.io.IOException;
import com.dataiku.dip.utils.JF;
import java.util.List;
import com.dataiku.dip.utils.JF.ObjectBuilder;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMClient.ChatMessage;
import com.dataiku.dip.connections.AbstractLLMConnection.HTTPBasedLLMNetworkSettings;
import com.dataiku.dip.llm.online.LLMClient.StreamedCompletionResponseChunk;
import com.dataiku.dip.llm.online.LLMClient.StreamedCompletionResponseConsumer;
import com.dataiku.dip.llm.online.LLMClient.StreamedCompletionResponseFooter;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.openai.api.OpenAIChatQueryAdapter;
import com.dataiku.dip.llm.online.openai.OpenAIMode;
import com.dataiku.dip.llm.online.openai.api.OpenAIChatQuery;

import com.dataiku.dip.llm.online.openai.api.OpenAIChatResponse;
import com.dataiku.common.rpc.ExternalJSONAPIClient.EntityAndRequest;
import com.dataiku.dip.llm.online.LLMClient.SimpleCompletionResponse;
import com.dataiku.dip.utils.JSON;
import com.dataiku.common.rpc.ExternalJSONAPIClient;
import com.dataiku.dip.streaming.endpoints.httpsse.SSEDecoder;
import com.dataiku.dip.streaming.endpoints.httpsse.SSEDecoder.HTTPSSEEvent;
import javax.annotation.Nullable;
import com.google.gson.annotations.SerializedName;

import com.dataiku.dip.utils.DKULogger;


public class GptApiImplementation implements CustomApiImplementation {
    /** JSON object builder created by the constructor; not read by any other code visible in this class. */
    ObjectBuilder ob;
    /** Completion settings, (re)created by {@link #addSettingsInObject}. */
    CoreCompletionSettings ccs;
    /** Chat query, (re)built by {@link #addMessagesInObject}. */
    OpenAIChatQuery query;
    /** URL that requests are posted to. */
    String endpoint;
    /** Model identifier passed to the query adapter. */
    String model;

    /**
     * Creates an implementation bound to one endpoint and one model.
     *
     * @param endpointUrl URL used verbatim as the request endpoint
     * @param model       model identifier kept for later query construction
     */
    public GptApiImplementation(String endpointUrl, String model) {
        this.ob = JF.obj();
        this.model = model;
        // The URL is stored as given: no model name or "/chat/completions" suffix is appended.
        this.endpoint = endpointUrl;
    }

    public class OpenAIChatChunkResponse {
        public List<Choice> choices;
        @Nullable
        public OpenAIChatResponse.Usage usage;
     
        public OpenAIChatChunkResponse() {
        }
     }
     /**
      * A single choice inside a streamed chunk ({@code OpenAIChatChunkResponse.choices}).
      */
     public static class Choice {
        /** Partial message payload of this choice (content, tool calls, refusal). */
        public Delta delta;

        /** Bound to the JSON property {@code finish_reason}; may be absent. */
        @Nullable
        @SerializedName("finish_reason")
        public String finishReason;

        @Nullable
        @SerializedName("logprobs")
        // Same type as the logprobs field of the non-streamed choice (RawChatCompletionChoice.logProbs).
        public LogProbs logProbs;
    }

    /**
     * Tool-call element as it appears in {@code Delta.toolCalls}.
     * Only {@code function} is not marked nullable.
     */
    public static class PartToolCall {
        /** Optional index of the tool call. NOTE(review): presumably its position in the tool-call list - confirm against the API. */
        @Nullable
        public Integer index;

        /** Optional tool-call identifier. */
        @Nullable
        public String id;

        /** Function name/arguments part of this tool call. */
        public PartialToolCallFunction function;
    }

    /**
     * Function part of a {@link PartToolCall}; both fields are optional.
     */
    public static class PartialToolCallFunction {
        /** Optional function name. */
        @Nullable
        public String name;

        /** Optional arguments, kept as a raw string. */
        @Nullable
        public String arguments; // this is a JSON string
    }
    /**
     * Payload of a streamed {@link Choice}. Every field is optional.
     */
    public static class Delta {
        /** Optional text content. */
        @Nullable
        public String content;

        /** Optional tool calls, bound to the JSON property {@code tool_calls}. */
        @Nullable
        @SerializedName("tool_calls")
        public List<PartToolCall> toolCalls;

        /** Optional refusal text. */
        @Nullable
        public String refusal;
    }
    /**
     * A single choice of a non-streamed response ({@code RawChatCompletionResponse.choices}).
     */
    public static class RawChatCompletionChoice {
        /** Bound to the JSON property {@code finish_reason}. */
        @SerializedName("finish_reason")
        public String finishReason;

        /** Message of this choice; its {@code content} is what {@code sendPostObject} returns as text. */
        public ChoiceMessage message;

        /** Optional log-probability data, bound to the JSON property {@code logprobs}. */
        @Nullable
        @SerializedName("logprobs")
        public LogProbs logProbs;
    }
    
    /**
     * Message carried by a {@link RawChatCompletionChoice}.
     */
    public static class ChoiceMessage {
        /** Role of the message author, as sent by the server. */
        public String role;

        /** Optional text content. */
        @Nullable
        public String content;

        /** Optional tool calls, bound to the JSON property {@code tool_calls}. */
        @Nullable
        @SerializedName("tool_calls")
        public List<ToolCall> toolCalls;

        /** Optional refusal text. */
        @Nullable
        public String refusal;
    }

    /**
     * Container for log-probability entries; used by both {@link Choice} and
     * {@link RawChatCompletionChoice}.
     */
    public static class LogProbs {
        /** Optional list of per-token entries. */
        @Nullable
        public List<LogProbContent> content;
    }

    /**
     * One entry of {@code LogProbs.content}: a token, its log-probability and a
     * list of alternative tokens.
     */
    public static class LogProbContent {
        /** Token text. */
        public String token;

        /** Bound to the JSON property {@code logprob}. */
        @SerializedName("logprob")
        public double logProb;

        /** Bound to the JSON property {@code top_logprobs}. */
        @SerializedName("top_logprobs")
        public List<TopLogProbContent> topLogProbs;
    }

    /**
     * One entry of {@code LogProbContent.topLogProbs}.
     */
    public static class TopLogProbContent {
        /** Token text. */
        public String token;

        /** Bound to the JSON property {@code logprob}. */
        @SerializedName("logprob")
        public double logProb;
    }

    /**
     * Tool call as it appears in {@code ChoiceMessage.toolCalls}.
     */
    public static class ToolCall {
        /** Tool-call identifier. */
        public String id;
        /** Function name and arguments of the call. */
        public ToolCallFunction function;
    }

    /**
     * Function part of a {@link ToolCall}.
     */
    public static class ToolCallFunction {
        /** Function name. */
        public String name;
        /** Arguments, kept as a raw string. */
        public String arguments; // this is a JSON string
    }

    /**
     * Token counters bound to snake_case JSON properties.
     *
     * NOTE(review): not referenced by any code visible in this class - streaming
     * uses {@code OpenAIChatResponse.Usage} and the non-streamed call uses
     * {@code RawUsageResponse}. Confirm whether it is still needed.
     */
    private static class Usage {
        /** Bound to the JSON property {@code completion_tokens}. */
        @SerializedName("completion_tokens")
        public int completionTokens;

        /** Bound to the JSON property {@code prompt_tokens}. */
        @SerializedName("prompt_tokens")
        public int promptTokens;

        /** Bound to the JSON property {@code total_tokens}. */
        @SerializedName("total_tokens")
        public int totalTokens;
    }
    /**
     * Usage section of a non-streamed response ({@code RawChatCompletionResponse.usage}).
     * The fields carry snake_case names and no {@code @SerializedName}; presumably
     * they are matched to JSON keys by field name - verify against the JSON binding in use.
     */
    private static class RawUsageResponse {
        // Declared but not read by the code visible in this class.
        int total_tokens;
        // Copied into SimpleCompletionResponse.promptTokens by sendPostObject.
        int prompt_tokens;
        // Copied into SimpleCompletionResponse.completionTokens by sendPostObject.
        int completion_tokens;
    }


    /**
     * Top-level shape that {@code sendPostObject} asks the client to parse the
     * non-streamed response into.
     */
    private static class RawChatCompletionResponse {
        // Completion choices; sendPostObject only reads the first one.
        List<RawChatCompletionChoice> choices;
        // Token usage reported with the response.
        RawUsageResponse usage;
    }
    
//Add in ToolChoice and Tools
    /**
     * Replaces the stored completion settings with a new
     * {@link CoreCompletionSettings} filled from the arguments.
     *
     * @param model         not used by this method
     * @param maxTokens     value copied to {@code maxTokens}
     * @param temperature   value copied to {@code temperature}
     * @param topP          value copied to {@code topP}
     * @param topK          value copied to {@code topK}
     * @param stopSequences value copied to {@code stopSequences}
     * @param toolChoice    value copied to {@code toolChoice}
     * @param tools         value copied to {@code tools}
     */
    public void addSettingsInObject(String model, Integer maxTokens, Double temperature, Double topP, Integer topK, List<String> stopSequences, LLMClient.ToolChoice toolChoice, List<LLMClient.AbstractTool> tools) {
        CoreCompletionSettings settings = new CoreCompletionSettings();
        // Publish the new instance first, then fill it in through the local alias.
        this.ccs = settings;

        settings.maxTokens = maxTokens;
        settings.temperature = temperature;
        settings.topP = topP;
        settings.topK = topK;
        settings.stopSequences = stopSequences;
        settings.toolChoice = toolChoice;
        settings.tools = tools;
    }

    /**
     * Builds the chat query from the given messages, the stored model and the
     * stored completion settings, and keeps it for the next request.
     *
     * @param messages chat messages handed to the query adapter
     */
    public void addMessagesInObject(List<ChatMessage> messages) {
        OpenAIChatQuery adapted = OpenAIChatQueryAdapter.adaptForStreaming(OpenAIMode.OPENAI, this.model, messages, this.ccs, false);
        this.query = adapted;
    }

    /**
     * Sends the stored query as a non-streamed request and returns the text of
     * the first choice together with the reported token counts.
     *
     * <p>NOTE(review): this method turns streaming off on the shared
     * {@code query} object and does not restore it, so a later
     * {@code streamChatComplete} call on the same instance reuses the modified
     * query unless {@code addMessagesInObject} is called again.
     *
     * @param client          HTTP client used to post the query
     * @param networkSettings supplies the query timeout
     * @return response holding the first choice's content and, when the server
     *         reported usage, the prompt and completion token counts
     * @throws IOException           if the response has no usable choice or message
     * @throws IllegalStateException if no query has been built yet
     */
    public SimpleCompletionResponse sendPostObject(ExternalJSONAPIClient client, HTTPBasedLLMNetworkSettings networkSettings) throws IOException {
        if (query == null) {
            throw new IllegalStateException("addMessagesInObject must be called before sendPostObject");
        }

        // The stored query was built by the streaming adapter; disable streaming for this blocking call.
        query.stream = false;
        query.streamOptions = null;
        logger.debug("Batch:"+ JSON.pretty(query));
        logger.debug("Endpoint:"+ this.endpoint);

        RawChatCompletionResponse rcr = client.postObjectToJSON(this.endpoint, networkSettings.queryTimeoutMS, RawChatCompletionResponse.class, query);
        if (rcr == null || rcr.choices == null || rcr.choices.isEmpty()) {
            throw new IOException("Chat did not respond with valid completion");
        }

        RawChatCompletionChoice firstChoice = rcr.choices.get(0);
        if (firstChoice == null || firstChoice.message == null) {
            throw new IOException("Chat did not respond with valid completion");
        }

        SimpleCompletionResponse ret = new SimpleCompletionResponse();
        ret.text = firstChoice.message.content;
        // Usage may be missing from the response; leave the token counts untouched in that case.
        if (rcr.usage != null) {
            ret.promptTokens = rcr.usage.prompt_tokens;
            ret.completionTokens = rcr.usage.completion_tokens;
        }
        return ret;
    }
//Add in ToolChoice and Tools
    /**
     * Posts the stored query, decodes the reply as server-sent events and
     * forwards each non-empty chunk to the consumer.
     *
     * <p>The stream ends on a null event, an event without data, or the
     * {@code [DONE]} marker. If any chunk carried refusal text, no further
     * chunks are forwarded and a {@code LLMClient.RefusalException} holding the
     * concatenated refusal is thrown once the stream ends. Otherwise the consumer is
     * completed with a footer carrying the last usage and finish reason seen.
     *
     * <p>NOTE(review): neither the response entity nor the request returned by
     * the client is released here - confirm whether the client or
     * {@code SSEDecoder} takes care of that.
     *
     * @param client          HTTP client used to post the query
     * @param consumer        receives stream start, chunks and the final footer
     * @param networkSettings supplies the query timeout
     * @param toolChoice      not used by this method
     * @param tools           not used by this method
     * @throws Exception on transport or parsing failure, or a
     *                   {@code LLMClient.RefusalException} when the model refused
     */
    public void streamChatComplete(ExternalJSONAPIClient client, StreamedCompletionResponseConsumer consumer, HTTPBasedLLMNetworkSettings networkSettings, LLMClient.ToolChoice toolChoice, List<LLMClient.AbstractTool> tools) throws Exception {

        logger.debug("Final Request AdamG:"+ JSON.pretty(query));
        EntityAndRequest ear = client.postJSONToStreamAndRequest(this.endpoint, networkSettings.queryTimeoutMS, query);
        SSEDecoder decoder = new SSEDecoder(ear.entity.getContent());
        consumer.onStreamStarted();

        OpenAIChatResponse.Usage usage = null;
        LLMClient.FinishReason finishReason = null;
        StringBuilder refusalBuilder = null;

        while (true) {
            HTTPSSEEvent event = decoder.next();
            if (logger.isTraceEnabled()) {
                logger.trace("Received raw event from OpenAI: " + JSON.json(event));
            }

            if (event == null || event.data == null) {
                logger.info("End of OpenAI stream");
                break;
            }

            if (event.data.equals("[DONE]")) {
                logger.info("Received explicit end marker from OpenAI stream");
                break;
            }

            OpenAIChatChunkResponse response = JSON.parse(event.data, OpenAIChatChunkResponse.class);
            if (response == null) {
                continue; // blank payload: nothing was parsed, nothing to do
            }

            // Keep the most recent usage; it may arrive on a chunk that has no choices.
            if (response.usage != null) {
                usage = response.usage;
            }

            // A chunk may omit "choices" entirely, so check for null as well as empty.
            if (response.choices == null || response.choices.isEmpty()) {
                continue; // nothing to do here
            }

            LLMClient.FinishReason reason = GPTAIChatChunkResponseAdapter.extractFinishReason(response);
            if (reason != null) {
                finishReason = reason;
            }

            String refusalChunk = GPTAIChatChunkResponseAdapter.getRefusal(response);
            if (refusalChunk != null) {
                if (refusalBuilder == null) {
                    refusalBuilder = new StringBuilder();
                }
                refusalBuilder.append(refusalChunk);
            }

            // Once a refusal has started, stop forwarding chunks to the consumer.
            if (refusalBuilder == null) {
                StreamedCompletionResponseChunk chunk = GPTAIChatChunkResponseAdapter.adapt(response);
                if (!chunk.isEmpty()) {
                    consumer.onStreamChunk(chunk);
                }
            }
        }

        if (refusalBuilder != null) {
            LLMClient.RefusalException refusalException = new LLMClient.RefusalException(refusalBuilder.toString());
            if (usage != null) {
                refusalException.completionTokens = usage.completionTokens;
                refusalException.promptTokens = usage.promptTokens;
                refusalException.totalTokens = usage.totalTokens;
            }
            throw refusalException;
        }

        StreamedCompletionResponseFooter footer = new StreamedCompletionResponseFooter();
        if (usage != null) {
            footer.completionTokens = usage.completionTokens;
            footer.promptTokens = usage.promptTokens;
            footer.totalTokens = usage.totalTokens;
        }

        if (finishReason != null) {
            footer.finishReason = finishReason;
        }

        consumer.onStreamComplete(footer);
    }

    // Class-wide logger. NOTE(review): the logger name mentions "nvidia.nim" although this class is the GPT implementation - confirm the name is intended.
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.customplugin.nvidia.nim");
}
