/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.mistralai.api;

import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.mistralai.api.MistralAIChatQuery;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * Translates the provider-agnostic {@link LLMClient} chat model (messages, tools, tool choice) and the
 * {@link CoreCompletionSettings} into a {@link MistralAIChatQuery} request object.
 *
 * <p>Static utility; not instantiable.
 */
public class MistralAIChatQueryAdapter {
    private MistralAIChatQueryAdapter() {
        // static helpers only
    }

    /**
     * Builds a non-streaming chat query targeting {@code model}.
     *
     * @param model    value written to the query's {@code model} field
     * @param messages conversation to translate, in order
     * @param ccs      completion settings to translate
     * @return the translated query with {@code model} set
     * @throws IllegalArgumentException if a message, tool, tool call or tool choice cannot be expressed
     *                                  in the MistralAI query model
     */
    public static MistralAIChatQuery adapt(String model, List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) {
        MistralAIChatQuery result = adapt(messages, ccs);
        result.model = model;
        return result;
    }

    /**
     * Builds a chat query from messages and settings, leaving the model unset.
     *
     * <p>{@code maxTokens}, {@code temperature}, {@code topP} and {@code toolChoice} are copied only when
     * non-null, and {@code tools} only when non-null and non-empty; otherwise the query's own field values
     * are left untouched. A {@link LLMClient.ResponseFormatJson} response format is mapped to
     * {@link MistralAIChatQuery.MistralAIResponseFormatJSONObject}; any other response format is not
     * copied.
     *
     * @param messages conversation to translate, in order
     * @param ccs      completion settings to translate
     * @return the translated query
     * @throws IllegalArgumentException if a message, tool, tool call or tool choice cannot be expressed
     *                                  in the MistralAI query model
     */
    public static MistralAIChatQuery adapt(List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) {
        MistralAIChatQuery result = new MistralAIChatQuery();

        // One generic message can expand into several MistralAI messages (tool messages fan out per output).
        result.messages = messages.stream()
                .map(MistralAIChatQueryAdapter::toMistralMessages)
                .flatMap(Collection::stream)
                .collect(Collectors.toList());

        if (ccs.maxTokens != null) {
            result.maxTokens = ccs.maxTokens;
        }
        if (ccs.temperature != null) {
            result.temperature = ccs.temperature;
        }
        if (ccs.topP != null) {
            result.topP = ccs.topP;
        }
        if (ccs.responseFormat instanceof LLMClient.ResponseFormatJson) {
            result.responseFormat = new MistralAIChatQuery.MistralAIResponseFormatJSONObject();
        }
        if (ccs.toolChoice != null) {
            result.toolChoice = toToolChoice(ccs.toolChoice);
        }
        boolean hasTools = ccs.tools != null && !ccs.tools.isEmpty();
        if (hasTools) {
            result.tools = ccs.tools.stream()
                    .map(MistralAIChatQueryAdapter::toFunctionTool)
                    .collect(Collectors.toList());
        }
        return result;
    }

    /**
     * Same as {@link #adapt(String, List, CoreCompletionSettings)} but with streaming enabled on the query.
     */
    public static MistralAIChatQuery adaptForStream(String model, List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) {
        MistralAIChatQuery streamingQuery = adapt(model, messages, ccs);
        streamingQuery.stream = true;
        return streamingQuery;
    }

    /**
     * Translates one generic message into its MistralAI message(s): user, system and assistant messages map
     * one-to-one, while a tool message yields one MistralAI tool message per tool output. The
     * {@code memoryFragment} role and any unrecognised role are rejected.
     */
    private static List<MistralAIChatQuery.Message> toMistralMessages(LLMClient.ChatMessage message) {
        String role = message.role;
        switch (role) {
            case "user":
                return List.of(toUserMessage(message));
            case "system":
                return List.of(toSystemMessage(message));
            case "assistant":
                return List.of(toAssistantMessage(message));
            case "tool":
                return toToolMessages(message);
            case "memoryFragment":
                throw new IllegalArgumentException("Unsupported chat message role: memoryFragment. You probably switched from a reasoning model to a non-reasoning model which is not supported.");
            default:
                throw new IllegalArgumentException(String.format("Unsupported chat message role: %s", role));
        }
    }

    /** Builds a MistralAI user message; only text-only messages are accepted. */
    private static MistralAIChatQuery.Message toUserMessage(LLMClient.ChatMessage message) {
        if (!message.isTextOnly()) {
            throw new IllegalArgumentException(String.format("Chat message with role %s must be text-only", message.role));
        }
        MistralAIChatQuery.UserMessage user = new MistralAIChatQuery.UserMessage();
        user.content = message.getText();
        return user;
    }

    /** Builds a MistralAI system message; only text-only messages are accepted. */
    private static MistralAIChatQuery.Message toSystemMessage(LLMClient.ChatMessage message) {
        if (!message.isTextOnly()) {
            throw new IllegalArgumentException(String.format("Chat message with role %s must be text-only", message.role));
        }
        MistralAIChatQuery.SystemMessage system = new MistralAIChatQuery.SystemMessage();
        system.content = message.getText();
        return system;
    }

    /**
     * Builds a MistralAI assistant message, carrying over any tool calls. Only text-only messages are
     * accepted.
     */
    private static MistralAIChatQuery.Message toAssistantMessage(LLMClient.ChatMessage message) {
        if (!message.isTextOnly()) {
            throw new IllegalArgumentException(String.format("Chat message with role: %s must be text-only", message.role));
        }
        MistralAIChatQuery.AssistantMessage assistant = new MistralAIChatQuery.AssistantMessage();
        // A null text is replaced by "" so the assistant content is never null.
        assistant.content = Objects.requireNonNullElse(message.getText(), "");
        if (message.toolCalls != null) {
            assistant.toolCalls = message.toolCalls.stream()
                    .map(MistralAIChatQueryAdapter::toToolCall)
                    .collect(Collectors.toList());
        }
        return assistant;
    }

    /**
     * Expands a generic tool message into one MistralAI tool message per tool output. The message must be
     * text-only and must carry tool outputs.
     */
    private static List<MistralAIChatQuery.Message> toToolMessages(LLMClient.ChatMessage message) {
        if (!message.isTextOnly()) {
            throw new IllegalArgumentException(String.format("Chat message with role: %s must be text-only", message.role));
        }
        if (message.toolOutputs == null) {
            throw new IllegalArgumentException(String.format("Chat message with role: %s must have tool outputs", message.role));
        }
        return message.toolOutputs.stream()
                .map(MistralAIChatQueryAdapter::toToolMessage)
                .collect(Collectors.toList());
    }

    /** Wraps a single tool output as a MistralAI tool message linked to its originating call id. */
    private static MistralAIChatQuery.ToolMessage toToolMessage(LLMClient.ToolOutput output) {
        MistralAIChatQuery.ToolMessage toolMessage = new MistralAIChatQuery.ToolMessage();
        toolMessage.toolCallId = output.callId;
        toolMessage.content = output.output;
        return toolMessage;
    }

    /** Converts a function tool call; any other kind of tool call is rejected. */
    private static MistralAIChatQuery.ToolCall toToolCall(LLMClient.AbstractToolCall call) {
        if (!(call instanceof LLMClient.FunctionToolCall)) {
            throw new IllegalArgumentException(String.format("Unknown tool call: %s", call.getClass().getSimpleName()));
        }
        LLMClient.FunctionToolCall functionCall = (LLMClient.FunctionToolCall) call;

        MistralAIChatQuery.ToolCallFunction invoked = new MistralAIChatQuery.ToolCallFunction();
        invoked.name = functionCall.function.name;
        invoked.arguments = functionCall.function.arguments;

        MistralAIChatQuery.ToolCall toolCall = new MistralAIChatQuery.ToolCall();
        toolCall.id = functionCall.id;
        toolCall.function = invoked;
        return toolCall;
    }

    /**
     * Maps a generic tool choice onto a MistralAI tool choice. Forcing a specific named tool is rejected,
     * as is any unrecognised choice type.
     */
    private static MistralAIChatQuery.ToolChoice toToolChoice(LLMClient.ToolChoice choice) {
        final MistralAIChatQuery.ToolChoice mapped;
        if (choice instanceof LLMClient.NoneToolChoice) {
            mapped = MistralAIChatQuery.ToolChoice.none;
        } else if (choice instanceof LLMClient.AutoToolChoice) {
            mapped = MistralAIChatQuery.ToolChoice.auto;
        } else if (choice instanceof LLMClient.RequiredToolChoice) {
            // The generic "required" choice is expressed with MistralAI's "any" value.
            mapped = MistralAIChatQuery.ToolChoice.any;
        } else if (choice instanceof LLMClient.NamedToolChoice) {
            throw new IllegalArgumentException("The MistralAI chat completion API does not support tool choice: \"tool_name\"");
        } else {
            throw new IllegalArgumentException(String.format("Unknown tool choice: %s", choice.getClass().getSimpleName()));
        }
        return mapped;
    }

    /** Converts a function tool definition; any other kind of tool is rejected. */
    private static MistralAIChatQuery.FunctionTool toFunctionTool(LLMClient.AbstractTool tool) {
        if (!(tool instanceof LLMClient.FunctionTool)) {
            throw new IllegalArgumentException(String.format("Unknown tool: %s", tool.getClass().getSimpleName()));
        }
        LLMClient.FunctionTool source = (LLMClient.FunctionTool) tool;

        MistralAIChatQuery.FunctionToolDesc desc = new MistralAIChatQuery.FunctionToolDesc();
        desc.name = source.function.name;
        desc.description = source.function.description;
        desc.parameters = source.function.getParameters();

        MistralAIChatQuery.FunctionTool target = new MistralAIChatQuery.FunctionTool();
        target.function = desc;
        return target;
    }
}

