/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.openai.api;

import com.dataiku.common.stereotype.RoutinelyUsedInExtensionCode;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.openai.OpenAIMode;
import com.dataiku.dip.llm.online.openai.api.OpenAIChatQuery;
import com.dataiku.dip.llm.utils.json_schema.JSONSchemaCompatibilityEnhancer;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang.StringUtils;

/**
 * Static helpers that translate the generic chat request model ({@link LLMClient.ChatMessage},
 * {@link CoreCompletionSettings}) into an {@link OpenAIChatQuery}.
 *
 * <p>The target {@link OpenAIMode} affects a few fields. The model name and the stream options are only sent in
 * {@code OPENAI} mode. {@code AZURE_LLM} encodes the "required" tool choice as {@code any} rather than
 * {@code required}.
 */
@RoutinelyUsedInExtensionCode
public class OpenAIChatQueryAdapter {
    /** Characters outside this set are replaced with '-' when a schema title is turned into a schema name. */
    private static final Pattern INVALID_SCHEMA_NAME_CHARS = Pattern.compile("[^a-zA-Z0-9-]");

    /** Maximum length of a {@code json_schema} name accepted by the OpenAI API; longer names are rejected. */
    private static final int MAX_SCHEMA_NAME_LENGTH = 64;

    /** Schema name used when the JSON schema has no usable title. */
    private static final String DEFAULT_SCHEMA_NAME = "JSON";

    private OpenAIChatQueryAdapter() {
    }

    /**
     * Builds a non-streaming chat query.
     *
     * @param mode target API flavour; decides whether the model name is sent and how tool choices are encoded
     * @param model model name, only copied to the query in {@code OPENAI} mode
     * @param messages chat history; a single {@code tool} message may expand into several OpenAI tool messages
     * @param ccs completion settings; {@code null} (and empty collection/map) fields are left unset on the query
     * @param useMaxCompletionToken if true, {@code ccs.maxTokens} is sent as {@code maxCompletionTokens}
     *                              instead of {@code maxTokens}
     * @return the populated query
     * @throws IllegalArgumentException if {@code topLogProbs} is set while {@code logProbs} is not {@code true},
     *                                  or if a message, tool, tool call, tool choice, response format or mode
     *                                  cannot be mapped
     */
    @RoutinelyUsedInExtensionCode
    public static OpenAIChatQuery adapt(OpenAIMode mode, String model, List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs, boolean useMaxCompletionToken) {
        OpenAIChatQuery query = new OpenAIChatQuery();
        query.model = OpenAIChatQueryAdapter.adaptModel(mode, model);
        query.messages = messages.stream().map(OpenAIChatQueryAdapter::adapt).flatMap(Collection::stream).collect(Collectors.toList());
        if (ccs.maxTokens != null) {
            if (useMaxCompletionToken) {
                query.maxCompletionTokens = ccs.maxTokens;
            } else {
                query.maxTokens = ccs.maxTokens;
            }
        }
        if (ccs.temperature != null) {
            query.temperature = ccs.temperature;
        }
        if (ccs.topP != null) {
            query.topP = ccs.topP;
        }
        if (ccs.frequencyPenalty != null) {
            query.frequencyPenalty = ccs.frequencyPenalty;
        }
        if (ccs.presencePenalty != null) {
            query.presencePenalty = ccs.presencePenalty;
        }
        if (ccs.logitBias != null && !ccs.logitBias.isEmpty()) {
            query.logitBias = new HashMap<String, Double>();
            // Token ids are sent as string keys.
            ccs.logitBias.forEach((tokenId, bias) -> query.logitBias.put(tokenId.toString(), (Double)bias));
        }
        // Compare by value: a reference comparison against Boolean.TRUE misses equal Boolean instances
        // that are not the cached constant.
        if (!Boolean.TRUE.equals(ccs.logProbs) && ccs.topLogProbs != null) {
            throw new IllegalArgumentException("Setting 'topLogProbs' requires 'logProbs' to be enabled");
        }
        if (ccs.logProbs != null) {
            query.logProbs = ccs.logProbs;
        }
        if (ccs.topLogProbs != null) {
            query.topLogProbs = ccs.topLogProbs;
        }
        if (ccs.stopSequences != null && !ccs.stopSequences.isEmpty()) {
            query.stop = ccs.stopSequences;
        }
        if (ccs.toolChoice != null) {
            query.toolChoice = OpenAIChatQueryAdapter.adaptToolChoice(mode, ccs.toolChoice);
        }
        if (ccs.tools != null && !ccs.tools.isEmpty()) {
            query.tools = ccs.tools.stream().map(OpenAIChatQueryAdapter::adapt).collect(Collectors.toList());
        }
        if (ccs.responseFormat != null) {
            if (ccs.responseFormat instanceof LLMClient.ResponseFormatText) {
                query.responseFormat = new OpenAIChatQuery.OpenAIResponseFormatText();
            } else if (ccs.responseFormat instanceof LLMClient.ResponseFormatJson) {
                LLMClient.ResponseFormatJson responseFormatJson = (LLMClient.ResponseFormatJson)ccs.responseFormat;
                if (responseFormatJson.schema == null) {
                    // JSON output requested without a schema.
                    query.responseFormat = new OpenAIChatQuery.OpenAIResponseFormatJSONObject();
                } else {
                    query.responseFormat = OpenAIChatQueryAdapter.adaptJsonSchemaResponseFormat(responseFormatJson);
                }
            } else {
                throw new IllegalArgumentException(String.format("Unknown response format: %s", ccs.responseFormat.getClass().getSimpleName()));
            }
        }
        return query;
    }

    /**
     * Builds a streaming chat query. It is the same as
     * {@link #adapt(OpenAIMode, String, List, CoreCompletionSettings, boolean)}, with {@code stream} set to true
     * and the mode-specific stream options from {@link #adaptStreamOptions(OpenAIMode)}.
     */
    @RoutinelyUsedInExtensionCode
    public static OpenAIChatQuery adaptForStreaming(OpenAIMode mode, String model, List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs, boolean useMaxCompletionToken) {
        OpenAIChatQuery query = OpenAIChatQueryAdapter.adapt(mode, model, messages, ccs, useMaxCompletionToken);
        query.stream = true;
        query.streamOptions = OpenAIChatQueryAdapter.adaptStreamOptions(mode);
        return query;
    }

    /**
     * Returns the stream options for the given mode.
     *
     * @return options with {@code includeUsage = true} for {@code OPENAI}; {@code null} (no stream options sent)
     *         for the Azure modes
     * @throws IllegalArgumentException for an unhandled mode
     */
    @Nullable
    public static OpenAIChatQuery.StreamOptions adaptStreamOptions(OpenAIMode mode) {
        switch (mode) {
            case AZURE_OPENAI:
            case AZURE_LLM: {
                return null;
            }
            case OPENAI: {
                OpenAIChatQuery.StreamOptions streamOptions = new OpenAIChatQuery.StreamOptions();
                streamOptions.includeUsage = true;
                return streamOptions;
            }
        }
        throw new IllegalArgumentException(String.format("Unknown mode: %s", mode.name()));
    }

    /** Returns the model name to put in the query: the given name for {@code OPENAI}, {@code null} for Azure modes. */
    @Nullable
    private static String adaptModel(OpenAIMode mode, String model) {
        switch (mode) {
            case AZURE_OPENAI:
            case AZURE_LLM: {
                return null;
            }
            case OPENAI: {
                return model;
            }
        }
        throw new IllegalArgumentException(String.format("Unknown mode: %s", mode.name()));
    }

    /**
     * Maps one generic chat message to OpenAI messages. A {@code tool} message yields one entry per tool output;
     * the other roles yield exactly one entry.
     *
     * @throws IllegalArgumentException if the role is {@code null} or not one of user/system/assistant/tool,
     *                                  or if the message content is not valid for its role
     */
    private static List<OpenAIChatQuery.Message> adapt(LLMClient.ChatMessage message) {
        // switch on a null String throws NullPointerException; report a null role like any other unsupported one.
        if (message.role != null) {
            switch (message.role) {
                case "user": {
                    return List.of(OpenAIChatQueryAdapter.adaptUserMessage(message));
                }
                case "system": {
                    return List.of(OpenAIChatQueryAdapter.adaptSystemMessage(message));
                }
                case "assistant": {
                    return List.of(OpenAIChatQueryAdapter.adaptAssistantMessage(message));
                }
                case "tool": {
                    return OpenAIChatQueryAdapter.adaptToolMessage(message);
                }
            }
        }
        throw new IllegalArgumentException(String.format("Unsupported chat message role: %s", message.role));
    }

    /**
     * Maps a user message: a text-only message becomes a plain-text user message, anything else a composite
     * message with one content part per message part.
     */
    private static OpenAIChatQuery.Message adaptUserMessage(LLMClient.ChatMessage message) {
        if (message.isTextOnly()) {
            OpenAIChatQuery.UserMessage msg = new OpenAIChatQuery.UserMessage();
            msg.content = message.getText();
            return msg;
        }
        OpenAIChatQuery.CompositeUserMessage msg = new OpenAIChatQuery.CompositeUserMessage();
        msg.content = message.parts.stream().map(OpenAIChatQueryAdapter::adaptUserMessagePart).collect(Collectors.toList());
        return msg;
    }

    /**
     * Maps one part of a composite user message. An inline image is encoded as a {@code data:} URL, using
     * {@code image/jpeg} when no MIME type is given. An image URI is passed through unchanged, and text becomes
     * a text part.
     *
     * @throws IllegalArgumentException if the part type is {@code null} or not handled
     */
    private static OpenAIChatQuery.MessageContentPart adaptUserMessagePart(LLMClient.ChatMessagePart part) {
        // switch on a null enum throws NullPointerException; report a null type like any other unsupported one.
        if (part.type != null) {
            switch (part.type) {
                case IMAGE_INLINE: {
                    String mime = StringUtils.isBlank(part.imageMimeType) ? "image/jpeg" : part.imageMimeType;
                    String url = String.format("data:%s;base64,%s", mime, part.inlineImage);
                    OpenAIChatQuery.ImageContentPart icp = new OpenAIChatQuery.ImageContentPart();
                    icp.imageUrl = new OpenAIChatQuery.ImageURL();
                    icp.imageUrl.url = url;
                    return icp;
                }
                case IMAGE_URI: {
                    OpenAIChatQuery.ImageContentPart icp = new OpenAIChatQuery.ImageContentPart();
                    icp.imageUrl = new OpenAIChatQuery.ImageURL();
                    icp.imageUrl.url = part.imageUrl;
                    return icp;
                }
                case TEXT: {
                    OpenAIChatQuery.TextContentPart tcp = new OpenAIChatQuery.TextContentPart();
                    tcp.text = part.text;
                    return tcp;
                }
            }
        }
        throw new IllegalArgumentException(String.format("Unsupported user chat message part type: %s", part.type));
    }

    /**
     * Maps a system message.
     *
     * @throws IllegalArgumentException if the message is not text-only
     */
    private static OpenAIChatQuery.Message adaptSystemMessage(LLMClient.ChatMessage message) {
        if (message.isTextOnly()) {
            OpenAIChatQuery.SystemMessage msg = new OpenAIChatQuery.SystemMessage();
            msg.content = message.getText();
            return msg;
        }
        throw new IllegalArgumentException(String.format("Chat message with role %s must be text-only", message.role));
    }

    /**
     * Maps an assistant message, including its tool calls when present.
     *
     * @throws IllegalArgumentException if the message is not text-only or contains an unknown tool call type
     */
    private static OpenAIChatQuery.Message adaptAssistantMessage(LLMClient.ChatMessage message) {
        if (!message.isTextOnly()) {
            throw new IllegalArgumentException(String.format("Chat message with role: %s must be text-only", message.role));
        }
        OpenAIChatQuery.AssistantMessage msg = new OpenAIChatQuery.AssistantMessage();
        msg.content = message.getText();
        if (message.toolCalls != null) {
            msg.toolCalls = message.toolCalls.stream().map(OpenAIChatQueryAdapter::adapt).collect(Collectors.toList());
        }
        return msg;
    }

    /**
     * Maps a tool message to one OpenAI tool message per tool output.
     *
     * @throws IllegalArgumentException if the message is not text-only or has no tool outputs
     */
    private static List<OpenAIChatQuery.Message> adaptToolMessage(LLMClient.ChatMessage message) {
        if (!message.isTextOnly()) {
            throw new IllegalArgumentException(String.format("Chat message with role: %s must be text-only", message.role));
        }
        if (message.toolOutputs == null) {
            throw new IllegalArgumentException(String.format("Chat message with role: %s must have tool outputs", message.role));
        }
        return message.toolOutputs.stream().map(OpenAIChatQueryAdapter::adapt).collect(Collectors.toList());
    }

    /** Maps one tool output to a tool message, keeping the call id and output text. */
    private static OpenAIChatQuery.ToolMessage adapt(LLMClient.ToolOutput to) {
        OpenAIChatQuery.ToolMessage tm = new OpenAIChatQuery.ToolMessage();
        tm.toolCallId = to.callId;
        tm.content = to.output;
        return tm;
    }

    /**
     * Maps a tool call made by the assistant. Only function tool calls are supported.
     *
     * @throws IllegalArgumentException for any other tool call type
     */
    private static OpenAIChatQuery.ToolCall adapt(LLMClient.AbstractToolCall atc) {
        if (atc instanceof LLMClient.FunctionToolCall) {
            LLMClient.FunctionToolCall ftc = (LLMClient.FunctionToolCall)atc;
            OpenAIChatQuery.ToolCall tc = new OpenAIChatQuery.ToolCall();
            tc.id = ftc.id;
            tc.function = new OpenAIChatQuery.ToolCallFunction();
            tc.function.name = ftc.function.name;
            tc.function.arguments = ftc.function.arguments;
            return tc;
        }
        throw new IllegalArgumentException(String.format("Unknown tool call: %s", atc.getClass().getSimpleName()));
    }

    /**
     * Maps a tool choice for the given mode. The modes differ only in how "required" is encoded:
     * {@code required} for {@code OPENAI}/{@code AZURE_OPENAI}, {@code any} for {@code AZURE_LLM}.
     *
     * @throws IllegalArgumentException for an unhandled mode or tool choice type
     */
    private static OpenAIChatQuery.ToolChoice adaptToolChoice(OpenAIMode mode, LLMClient.ToolChoice choice) {
        switch (mode) {
            case AZURE_OPENAI:
            case OPENAI: {
                return OpenAIChatQueryAdapter.adaptToolChoice(choice, OpenAIChatQuery.ConstantToolChoice.required);
            }
            case AZURE_LLM: {
                return OpenAIChatQueryAdapter.adaptToolChoice(choice, OpenAIChatQuery.ConstantToolChoice.any);
            }
        }
        throw new IllegalArgumentException(String.format("Unknown mode: %s", mode.name()));
    }

    /**
     * Shared tool-choice mapping. "none" and "auto" map to their constants, a named choice forces that function,
     * and "required" maps to {@code requiredToolChoice}, which differs between modes.
     *
     * @throws IllegalArgumentException for an unknown tool choice type
     */
    private static OpenAIChatQuery.ToolChoice adaptToolChoice(LLMClient.ToolChoice choice, OpenAIChatQuery.ToolChoice requiredToolChoice) {
        if (choice instanceof LLMClient.NoneToolChoice) {
            return OpenAIChatQuery.ConstantToolChoice.none;
        }
        if (choice instanceof LLMClient.RequiredToolChoice) {
            return requiredToolChoice;
        }
        if (choice instanceof LLMClient.AutoToolChoice) {
            return OpenAIChatQuery.ConstantToolChoice.auto;
        }
        if (choice instanceof LLMClient.NamedToolChoice) {
            LLMClient.NamedToolChoice ntc = (LLMClient.NamedToolChoice)choice;
            OpenAIChatQuery.EnforcedToolChoice etc = new OpenAIChatQuery.EnforcedToolChoice();
            etc.function = new OpenAIChatQuery.ToolChoiceFunction();
            etc.function.name = ntc.name;
            return etc;
        }
        throw new IllegalArgumentException(String.format("Unknown tool choice: %s", choice.getClass().getSimpleName()));
    }

    /**
     * Maps a tool definition. Only function tools are supported.
     *
     * @throws IllegalArgumentException for any other tool type
     */
    private static OpenAIChatQuery.FunctionTool adapt(LLMClient.AbstractTool tool) {
        if (tool instanceof LLMClient.FunctionTool) {
            LLMClient.FunctionTool fDesc = (LLMClient.FunctionTool)tool;
            OpenAIChatQuery.FunctionTool ft = new OpenAIChatQuery.FunctionTool();
            ft.function = new OpenAIChatQuery.FunctionToolDesc();
            ft.function.name = fDesc.function.name;
            ft.function.description = fDesc.function.description;
            ft.function.parameters = fDesc.function.getParameters();
            return ft;
        }
        throw new IllegalArgumentException(String.format("Unknown tool: %s", tool.getClass().getSimpleName()));
    }

    /**
     * Builds a {@code json_schema} response format from a JSON response format that carries a schema.
     *
     * <p>{@code strict} defaults to true. {@code compatible} defaults to the effective {@code strict} value and
     * selects whether the schema goes through {@link JSONSchemaCompatibilityEnhancer} with the {@code OPENAI}
     * provider or with {@code PASSTHROUGH}. The schema's primitive {@code description} is copied. The name is
     * derived from its primitive {@code title}, falling back to {@code "JSON"} when the result is blank.
     */
    private static OpenAIChatQuery.OpenAIResponseFormatJSONSchema adaptJsonSchemaResponseFormat(LLMClient.ResponseFormatJson responseFormatJson) {
        JsonObject jsonSchema = responseFormatJson.schema;
        OpenAIChatQuery.OpenAIResponseFormatJSONSchema oaiResponseFormat = new OpenAIChatQuery.OpenAIResponseFormatJSONSchema();
        boolean strict = !Boolean.FALSE.equals(responseFormatJson.strict);
        boolean compatible = responseFormatJson.compatible == null ? strict : responseFormatJson.compatible;
        oaiResponseFormat.jsonSchema.schema = JSONSchemaCompatibilityEnhancer.enhance(jsonSchema, compatible ? JSONSchemaCompatibilityEnhancer.Provider.OPENAI : JSONSchemaCompatibilityEnhancer.Provider.PASSTHROUGH);
        JsonElement descriptionElement = jsonSchema.get("description");
        if (descriptionElement != null && descriptionElement.isJsonPrimitive()) {
            oaiResponseFormat.jsonSchema.description = descriptionElement.getAsString();
        }
        JsonElement titleElement = jsonSchema.get("title");
        if (titleElement != null && titleElement.isJsonPrimitive()) {
            String title = titleElement.getAsString();
            if (title != null) {
                oaiResponseFormat.jsonSchema.name = OpenAIChatQueryAdapter.toSchemaName(title);
            }
        }
        if (StringUtils.isBlank(oaiResponseFormat.jsonSchema.name)) {
            oaiResponseFormat.jsonSchema.name = DEFAULT_SCHEMA_NAME;
        }
        oaiResponseFormat.jsonSchema.strict = strict;
        return oaiResponseFormat;
    }

    /**
     * Turns a schema title into a {@code json_schema} name. Every character other than an ASCII letter, digit
     * or '-' becomes '-', and the result is truncated to 64 characters, the API's limit. After replacement the
     * name is pure ASCII, so truncation cannot split a surrogate pair.
     */
    private static String toSchemaName(String title) {
        String name = INVALID_SCHEMA_NAME_CHARS.matcher(title).replaceAll("-");
        return name.length() > MAX_SCHEMA_NAME_LENGTH ? name.substring(0, MAX_SCHEMA_NAME_LENGTH) : name;
    }
}

