/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.snowflakecortex;

import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.snowflakecortex.SnowflakeCompletionRESTQuery;
import com.dataiku.dip.utils.JSON;
import com.google.gson.JsonObject;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang.StringUtils;

/**
 * Converts DSS chat-completion inputs ({@link LLMClient.ChatMessage} lists and
 * {@link CoreCompletionSettings}) into {@link SnowflakeCompletionRESTQuery} request objects.
 *
 * <p>Stateless utility: all methods are static and the constructor is private. Unsupported or
 * malformed inputs (unknown roles, part types, tools, tool choices, inconsistent tool
 * call/result sequences) are reported as {@link IllegalArgumentException}.
 */
public class SnowflakeCompletionRESTQueryAdapter {
    private SnowflakeCompletionRESTQueryAdapter() {
    }

    /**
     * Builds a completion query exactly like {@link #adapt(String, List, CoreCompletionSettings)}
     * and additionally flags it as a streaming request.
     *
     * @param model    model identifier, copied verbatim into the query
     * @param messages conversation to convert, in order
     * @param ccs      completion settings (token limit, sampling, tools, tool choice)
     * @return the converted query with {@code stream} set to {@code true}
     * @throws IllegalArgumentException if any message, part, tool or tool choice cannot be converted
     */
    public static SnowflakeCompletionRESTQuery adaptForStreaming(String model, List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) {
        SnowflakeCompletionRESTQuery query = adapt(model, messages, ccs);
        query.stream = true;
        return query;
    }

    /**
     * Builds a (non-streaming) completion query from a model name, a conversation and settings.
     *
     * <p>Only settings that are non-null in {@code ccs} are copied. When at least one tool is
     * supplied and no tool choice results from {@code ccs.toolChoice}, the tool choice defaults
     * to {@code auto}.
     *
     * @param model    model identifier, copied verbatim into the query
     * @param messages conversation to convert, in order; each message yields exactly one query message
     * @param ccs      completion settings (token limit, sampling, tools, tool choice)
     * @return the converted query
     * @throws IllegalArgumentException if any message, part, tool or tool choice cannot be converted
     */
    public static SnowflakeCompletionRESTQuery adapt(String model, List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) {
        SnowflakeCompletionRESTQuery query = new SnowflakeCompletionRESTQuery();
        query.model = model;
        // Iterate by index so each tool message is paired with the message directly before it,
        // instead of re-locating it with indexOf (O(n) per lookup, and wrong for repeated messages).
        query.messages = IntStream.range(0, messages.size())
                .mapToObj(i -> adaptMessage(messages.get(i), i == 0 ? null : messages.get(i - 1)))
                .collect(Collectors.toList());
        if (ccs.maxTokens != null) {
            query.maxTokens = ccs.maxTokens;
        }
        if (ccs.temperature != null) {
            query.temperature = ccs.temperature;
        }
        if (ccs.topP != null) {
            query.topP = ccs.topP;
        }
        if (ccs.toolChoice != null) {
            query.toolChoice = adaptToolChoice(ccs.toolChoice);
        }
        if (ccs.tools != null && !ccs.tools.isEmpty()) {
            query.tools = ccs.tools.stream().map(SnowflakeCompletionRESTQueryAdapter::adaptTool).collect(Collectors.toList());
            if (query.toolChoice == null) {
                // NOTE(review): adaptToolChoice maps an explicit NoneToolChoice to null, so a caller
                // asking for "none" while passing tools ends up with "auto" here. Confirm whether the
                // Snowflake API offers a "none" choice (or whether tools should be omitted instead).
                query.toolChoice = new SnowflakeCompletionRESTQuery.ToolChoice();
                query.toolChoice.type = SnowflakeCompletionRESTQuery.ToolChoiceType.auto;
            }
        }
        return query;
    }

    /**
     * Converts one chat message according to its role.
     *
     * @param message  the message to convert
     * @param previous the message immediately preceding it in the conversation, or {@code null}
     *                 if {@code message} is the first one; only used for {@code tool} messages
     * @return the converted message
     * @throws IllegalArgumentException for a null or unsupported role, or an invalid message
     */
    private static SnowflakeCompletionRESTQuery.Message adaptMessage(LLMClient.ChatMessage message, LLMClient.ChatMessage previous) {
        String role = message.role;
        // Guard against null: switching on a null String would throw NullPointerException
        // instead of the descriptive IllegalArgumentException below.
        if (role != null) {
            switch (role) {
                case "user":
                    return adaptUserMessage(message);
                case "system":
                    return adaptSystemMessage(message);
                case "assistant":
                    return adaptAssistantMessage(message);
                case "tool":
                    return adaptToolMessage(message, previous);
                default:
                    break;
            }
        }
        throw new IllegalArgumentException(String.format("Unsupported chat message role: %s", role));
    }

    /**
     * Converts a user message: text-only messages become a plain {@code UserMessage}, anything
     * else becomes a {@code CompositeUserMessage} with one converted entry per part.
     */
    private static SnowflakeCompletionRESTQuery.Message adaptUserMessage(LLMClient.ChatMessage message) {
        if (message.isTextOnly()) {
            SnowflakeCompletionRESTQuery.UserMessage msg = new SnowflakeCompletionRESTQuery.UserMessage();
            msg.content = message.getText();
            return msg;
        }
        SnowflakeCompletionRESTQuery.CompositeUserMessage msg = new SnowflakeCompletionRESTQuery.CompositeUserMessage();
        msg.contentList = message.parts.stream().map(SnowflakeCompletionRESTQueryAdapter::adaptUserMessagePart).collect(Collectors.toList());
        return msg;
    }

    /**
     * Converts one part of a composite user message.
     *
     * <p>Inline images default to MIME type {@code image/jpeg} when none is set; image URIs are
     * rejected.
     *
     * @throws IllegalArgumentException for image URIs and for null or unsupported part types
     */
    private static SnowflakeCompletionRESTQuery.UserContentPart adaptUserMessagePart(LLMClient.ChatMessagePart part) {
        // Guard against null: switching on a null enum would throw NullPointerException.
        if (part.type != null) {
            switch (part.type) {
                case IMAGE_INLINE: {
                    SnowflakeCompletionRESTQuery.ImageContentPart icp = new SnowflakeCompletionRESTQuery.ImageContentPart();
                    SnowflakeCompletionRESTQuery.ImageContentDetails icd = new SnowflakeCompletionRESTQuery.ImageContentDetails();
                    icd.contentType = StringUtils.isBlank(part.imageMimeType) ? "image/jpeg" : part.imageMimeType;
                    icd.content = part.inlineImage;
                    icp.details = icd;
                    return icp;
                }
                case IMAGE_URI:
                    throw new IllegalArgumentException("Image URIs not supported by Snowflake Cortex connection. Use inline images instead.");
                case TEXT: {
                    SnowflakeCompletionRESTQuery.TextContentPart tcp = new SnowflakeCompletionRESTQuery.TextContentPart();
                    tcp.text = part.text;
                    return tcp;
                }
                default:
                    break;
            }
        }
        throw new IllegalArgumentException(String.format("Unsupported user chat message part type: %s", part.type));
    }

    /**
     * Converts a system message.
     *
     * @throws IllegalArgumentException if the message is not text-only
     */
    private static SnowflakeCompletionRESTQuery.Message adaptSystemMessage(LLMClient.ChatMessage message) {
        if (!message.isTextOnly()) {
            throw new IllegalArgumentException(String.format("Chat message with role %s must be text-only", message.role));
        }
        SnowflakeCompletionRESTQuery.SystemMessage msg = new SnowflakeCompletionRESTQuery.SystemMessage();
        msg.content = message.getText();
        return msg;
    }

    /**
     * Converts an assistant message, including any tool calls it carries.
     *
     * <p>When the message has no text, the placeholder {@code "Tool call"} is used as content.
     *
     * @throws IllegalArgumentException if the message is not text-only or a tool call cannot be converted
     */
    private static SnowflakeCompletionRESTQuery.Message adaptAssistantMessage(LLMClient.ChatMessage message) {
        if (!message.isTextOnly()) {
            throw new IllegalArgumentException(String.format("Chat message with role: %s must be text-only", message.role));
        }
        SnowflakeCompletionRESTQuery.AssistantMessage msg = new SnowflakeCompletionRESTQuery.AssistantMessage();
        msg.content = Objects.requireNonNullElse(message.getText(), "Tool call");
        if (message.toolCalls != null) {
            msg.contentList = message.toolCalls.stream().map(SnowflakeCompletionRESTQueryAdapter::adaptToolCall).collect(Collectors.toList());
        }
        return msg;
    }

    /**
     * Converts a tool-result message.
     *
     * <p>The message directly before it must be an assistant message whose tool calls are all
     * function calls and match the tool outputs in count. Those calls supply the function name
     * attached to each result.
     *
     * @param message  the tool message to convert
     * @param previous the message immediately before it, or {@code null} if it is the first message
     * @throws IllegalArgumentException if any of the checks above fails
     */
    private static SnowflakeCompletionRESTQuery.Message adaptToolMessage(LLMClient.ChatMessage message, LLMClient.ChatMessage previous) {
        if (!message.isTextOnly()) {
            throw new IllegalArgumentException(String.format("Chat message with role: %s must be text-only", message.role));
        }
        if (message.toolOutputs == null) {
            throw new IllegalArgumentException(String.format("Chat message with role: %s must have tool outputs", message.role));
        }
        if (previous == null) {
            throw new IllegalArgumentException("Tool result with out prior assistant message");
        }
        if (!"assistant".equals(previous.role)) {
            throw new IllegalArgumentException("Message before tool results should have role assistant");
        }
        if (CollectionUtils.isEmpty(previous.toolCalls)) {
            throw new IllegalArgumentException("Assistant message before tool results should include tool calls");
        }
        if (!previous.toolCalls.stream().allMatch(tc -> tc instanceof LLMClient.FunctionToolCall)) {
            throw new IllegalArgumentException("Assistant message before tool results should include function tool calls");
        }
        if (previous.toolCalls.size() != message.toolOutputs.size()) {
            throw new IllegalArgumentException("Different count of toolCalls and toolOutputs");
        }
        List<String> functionNames = new ArrayList<>(previous.toolCalls.size());
        Map<String, String> functionNamesByCallId = new HashMap<>();
        previous.toolCalls.forEach(tc -> {
            LLMClient.FunctionToolCall ftc = (LLMClient.FunctionToolCall) tc;
            functionNames.add(ftc.function.name);
            if (ftc.id != null) {
                // First occurrence wins if a call id is (unexpectedly) repeated.
                functionNamesByCallId.putIfAbsent(ftc.id, ftc.function.name);
            }
        });
        return adaptToolOutputs(message.toolOutputs, functionNames, functionNamesByCallId);
    }

    /**
     * Wraps tool outputs into a single {@code ToolMessage}, one result entry per output.
     *
     * @param tos                   tool outputs, in message order
     * @param functionNames         function names of the matching tool calls, in call order
     *                              (same size as {@code tos}); used as a positional fallback
     * @param functionNamesByCallId function names keyed by tool call id
     */
    private static SnowflakeCompletionRESTQuery.ToolMessage adaptToolOutputs(List<LLMClient.ToolOutput> tos, List<String> functionNames, Map<String, String> functionNamesByCallId) {
        SnowflakeCompletionRESTQuery.ToolMessage tm = new SnowflakeCompletionRESTQuery.ToolMessage();
        tm.contentList = new ArrayList<>();
        for (int i = 0; i < tos.size(); ++i) {
            LLMClient.ToolOutput to = tos.get(i);
            SnowflakeCompletionRESTQuery.ToolResultWrapper trw = new SnowflakeCompletionRESTQuery.ToolResultWrapper();
            trw.result = new SnowflakeCompletionRESTQuery.ToolResult();
            trw.result.toolUseId = to.callId;
            // Pair by call id so outputs listed in a different order than the calls still get the
            // right name; fall back to position when the id is missing or unknown.
            trw.result.name = functionNamesByCallId.getOrDefault(to.callId, functionNames.get(i));
            SnowflakeCompletionRESTQuery.ToolResultContent trc = new SnowflakeCompletionRESTQuery.ToolResultContent();
            trc.text = to.output;
            trw.result.content.add(trc);
            tm.contentList.add(trw);
        }
        return tm;
    }

    /**
     * Converts an assistant tool call into a {@code ToolUseWrapper}; only function calls are supported.
     *
     * @throws IllegalArgumentException for non-function calls, or when the function arguments
     *                                  cannot be parsed as a JSON object (the parse error is kept as cause)
     */
    private static SnowflakeCompletionRESTQuery.ToolUseWrapper adaptToolCall(LLMClient.AbstractToolCall atc) {
        if (!(atc instanceof LLMClient.FunctionToolCall)) {
            throw new IllegalArgumentException(String.format("Unknown tool call: %s", atc.getClass().getSimpleName()));
        }
        LLMClient.FunctionToolCall ftc = (LLMClient.FunctionToolCall) atc;
        SnowflakeCompletionRESTQuery.ToolUseWrapper tc = new SnowflakeCompletionRESTQuery.ToolUseWrapper();
        tc.toolUse.toolUseId = ftc.id;
        tc.toolUse.name = ftc.function.name;
        try {
            tc.toolUse.input = (JsonObject) JSON.parse((String) ftc.function.arguments, JsonObject.class);
        } catch (Exception e) {
            throw new IllegalArgumentException(String.format("Error while parsing as JSON arguments '%s' for function '%s'", ftc.function.arguments, ftc.function.name), e);
        }
        return tc;
    }

    /**
     * Converts a tool-choice setting.
     *
     * @return {@code null} for {@code NoneToolChoice}; otherwise a choice of type
     *         {@code auto}, {@code required} or {@code tool}
     * @throws IllegalArgumentException for an unknown tool-choice class
     */
    private static SnowflakeCompletionRESTQuery.ToolChoice adaptToolChoice(LLMClient.ToolChoice choice) {
        if (choice instanceof LLMClient.NoneToolChoice) {
            return null;
        }
        SnowflakeCompletionRESTQuery.ToolChoice ret = new SnowflakeCompletionRESTQuery.ToolChoice();
        if (choice instanceof LLMClient.AutoToolChoice) {
            ret.type = SnowflakeCompletionRESTQuery.ToolChoiceType.auto;
            return ret;
        }
        if (choice instanceof LLMClient.RequiredToolChoice) {
            ret.type = SnowflakeCompletionRESTQuery.ToolChoiceType.required;
            return ret;
        }
        if (choice instanceof LLMClient.NamedToolChoice) {
            // NOTE(review): the requested tool's name is not copied into the request; only the
            // type is set. Confirm whether the Snowflake "tool" choice needs the name(s) specified.
            ret.type = SnowflakeCompletionRESTQuery.ToolChoiceType.tool;
            return ret;
        }
        throw new IllegalArgumentException(String.format("Unknown tool choice: %s", choice.getClass().getSimpleName()));
    }

    /**
     * Converts a tool definition into a {@code FunctionTool} with a {@code "generic"} tool spec;
     * only function tools are supported.
     *
     * @throws IllegalArgumentException for any non-function tool
     */
    private static SnowflakeCompletionRESTQuery.FunctionTool adaptTool(LLMClient.AbstractTool tool) {
        if (!(tool instanceof LLMClient.FunctionTool)) {
            throw new IllegalArgumentException(String.format("Unknown tool: %s", tool.getClass().getSimpleName()));
        }
        LLMClient.FunctionTool fDesc = (LLMClient.FunctionTool) tool;
        SnowflakeCompletionRESTQuery.FunctionTool t = new SnowflakeCompletionRESTQuery.FunctionTool();
        t.toolSpec.type = "generic";
        t.toolSpec.name = fDesc.function.name;
        t.toolSpec.description = fDesc.function.description;
        t.toolSpec.inputSchema = fDesc.function.getParameters();
        return t;
    }
}

