/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.bedrock.converse;

import com.dataiku.dip.llm.online.LLMChatMessageUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.bedrock.RawBedrockClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.sagemakergeneric.GenericLLMHandling;
import com.dataiku.dip.llm.utils.json_schema.JSONSchemaCompatibilityEnhancer;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.NotImplementedException;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.core.SdkBytes;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.core.document.Document;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.protocols.json.internal.unmarshall.document.DocumentUnmarshaller;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.protocols.jsoncore.JsonNodeParser;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.protocols.jsoncore.JsonNodeVisitor;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ConversationRole;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ImageBlock;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ImageFormat;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ImageSource;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.InferenceConfiguration;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.Message;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ReasoningContentBlock;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ReasoningTextBlock;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.Tool;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ToolResultBlock;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.utils.builder.SdkBuilder;
import com.google.gson.JsonObject;
import com.google.gson.JsonSyntaxException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang.StringUtils;

/**
 * Static helpers that translate DSS chat messages and completion settings into
 * Amazon Bedrock Converse API requests (blocking and streaming variants).
 *
 * <p>The class is not instantiable; all entry points are static.
 */
public class ConverseChatQueryAdapter {
    /** Reasoning token budget sent when the reasoning effort is {@code STANDARD}. */
    private static final int STANDARD_REASONING_BUDGET_TOKENS = 8000;

    private ConverseChatQueryAdapter() {
    }

    /**
     * Builds a blocking Converse request.
     *
     * @param messages chat history; {@code "system"} messages become system blocks, the others conversation messages
     * @param settings Bedrock settings providing the model id, the model handling and the optional guardrail
     * @param ccs completion settings (inference parameters, tools, reasoning effort)
     * @param authCtx authentication context forwarded when resolving tool output parts
     * @return the fully built request
     * @throws IllegalArgumentException if a message, image, tool choice or reasoning setting is not supported
     */
    public static ConverseRequest adapt(List<LLMClient.ChatMessage> messages, RawBedrockClient.Settings settings, CoreCompletionSettings ccs, AuthCtx authCtx) {
        ConverseRequest.Builder queryBuilder = ConverseRequest.builder();
        queryBuilder.system(toSystemMessages(messages))
                .messages(toMessages(messages, authCtx))
                .inferenceConfig(getInferenceConfiguration(ccs))
                .modelId(settings.modelId);

        Document additionalFields = adaptAdditionalFields(ccs, settings.handling);
        if (hasContent(additionalFields)) {
            queryBuilder.additionalModelRequestFields(additionalFields);
        }
        if (ccs.tools != null && !ccs.tools.isEmpty()) {
            queryBuilder.toolConfig(adaptToolConfig(ccs, settings.handling));
        }
        if (settings.hasGuardrail()) {
            queryBuilder.guardrailConfig(builder -> builder.guardrailIdentifier(settings.guardrailIdentifier).guardrailVersion(settings.guardrailVersion));
        }
        return queryBuilder.build();
    }

    /**
     * Builds a streaming Converse request. Mirrors {@link #adapt(List, RawBedrockClient.Settings, CoreCompletionSettings, AuthCtx)}
     * but targets the streaming request type.
     *
     * @param messages chat history; {@code "system"} messages become system blocks, the others conversation messages
     * @param settings Bedrock settings providing the model id, the model handling and the optional guardrail
     * @param ccs completion settings (inference parameters, tools, reasoning effort)
     * @param authCtx authentication context forwarded when resolving tool output parts
     * @return the fully built streaming request
     * @throws IllegalArgumentException if a message, image, tool choice or reasoning setting is not supported
     */
    public static ConverseStreamRequest adaptStreaming(List<LLMClient.ChatMessage> messages, RawBedrockClient.Settings settings, CoreCompletionSettings ccs, AuthCtx authCtx) {
        ConverseStreamRequest.Builder queryBuilder = ConverseStreamRequest.builder();
        queryBuilder.system(toSystemMessages(messages))
                .messages(toMessages(messages, authCtx))
                .inferenceConfig(getInferenceConfiguration(ccs))
                .modelId(settings.modelId);

        Document additionalFields = adaptAdditionalFields(ccs, settings.handling);
        if (hasContent(additionalFields)) {
            queryBuilder.additionalModelRequestFields(additionalFields);
        }
        if (ccs.tools != null && !ccs.tools.isEmpty()) {
            queryBuilder.toolConfig(adaptToolConfig(ccs, settings.handling));
        }
        if (settings.hasGuardrail()) {
            queryBuilder.guardrailConfig(builder -> builder.guardrailIdentifier(settings.guardrailIdentifier).guardrailVersion(settings.guardrailVersion));
        }
        return queryBuilder.build();
    }

    /**
     * Tells whether the additional-fields document carries anything worth sending.
     *
     * <p>A document produced by a map builder is a map, not a null document, even when
     * no entry was added, so a plain {@code isNull()} test would always let it through.
     */
    private static boolean hasContent(Document additionalFields) {
        if (additionalFields.isNull()) {
            return false;
        }
        return !(additionalFields.isMap() && additionalFields.asMap().isEmpty());
    }

    /** Collects the text of every {@code "system"} message as a system content block, in order. */
    private static List<SystemContentBlock> toSystemMessages(List<LLMClient.ChatMessage> messages) {
        return messages.stream()
                .filter(m -> "system".equals(m.role))
                .map(LLMClient.ChatMessage::getText)
                .map(SystemContentBlock::fromText)
                .collect(Collectors.toList());
    }

    /** Copies the inference parameters that are set (max tokens, stop sequences, temperature, top-p). */
    private static InferenceConfiguration getInferenceConfiguration(CoreCompletionSettings ccs) {
        InferenceConfiguration.Builder cfgBuilder = InferenceConfiguration.builder();
        if (ccs.maxTokens != null) {
            cfgBuilder.maxTokens(ccs.maxTokens);
        }
        if (ccs.stopSequences != null && !ccs.stopSequences.isEmpty()) {
            cfgBuilder.stopSequences(ccs.stopSequences);
        }
        if (ccs.temperature != null) {
            cfgBuilder.temperature(Float.valueOf(ccs.temperature.floatValue()));
        }
        if (ccs.topP != null) {
            cfgBuilder.topP(Float.valueOf(ccs.topP.floatValue()));
        }
        return cfgBuilder.build();
    }

    /**
     * Builds the model-specific request fields: the top-k / penalty parameters under the
     * key names used for the given handling, plus the {@code thinking} configuration when a
     * reasoning effort other than {@code OFF} is requested.
     *
     * @return a map document, possibly empty
     * @throws IllegalArgumentException if the reasoning effort is unsupported, or if the custom
     *         reasoning effort is empty or not an integer
     */
    private static Document adaptAdditionalFields(CoreCompletionSettings ccs, GenericLLMHandling handling) {
        Document.MapBuilder fieldsBuilder = Document.mapBuilder();
        switch (handling) {
            case AMAZON_NOVA:
                if (ccs.topK != null) {
                    fieldsBuilder.putNumber("topK", ccs.topK.intValue());
                }
                break;
            case ANTHROPIC_CLAUDE_CHAT:
            case ANTHROPIC_CLAUDE:
            case MISTRAL_AI:
                if (ccs.topK != null) {
                    fieldsBuilder.putNumber("top_k", ccs.topK.intValue());
                }
                break;
            case COHERE_COMMAND_CHAT:
                if (ccs.topK != null) {
                    fieldsBuilder.putNumber("k", ccs.topK.intValue());
                }
                if (ccs.frequencyPenalty != null) {
                    fieldsBuilder.putNumber("frequency_penalty", ccs.frequencyPenalty.doubleValue());
                }
                if (ccs.presencePenalty != null) {
                    fieldsBuilder.putNumber("presence_penalty", ccs.presencePenalty.doubleValue());
                }
                break;
            default:
                // Other handlings have no model-specific sampling fields.
                break;
        }

        if (ccs.reasoningEffort != null && ccs.reasoningEffort != LLMClient.ReasoningEffort.OFF) {
            if (ccs.reasoningEffort == LLMClient.ReasoningEffort.STANDARD) {
                fieldsBuilder.putDocument("thinking", thinkingConfig(STANDARD_REASONING_BUDGET_TOKENS));
            } else if (ccs.reasoningEffort == LLMClient.ReasoningEffort.CUSTOM) {
                if (StringUtils.isEmpty(ccs.customReasoningEffort)) {
                    throw new IllegalArgumentException(String.format("Custom reasoning effort cannot be null or empty: %s ", ccs.customReasoningEffort));
                }
                int budgetTokens;
                try {
                    budgetTokens = Integer.parseInt(ccs.customReasoningEffort);
                } catch (NumberFormatException e) {
                    // Keep the parsing failure as the cause so the original error is not lost.
                    throw new IllegalArgumentException(String.format("Invalid 'customReasoningEffort' for Bedrock: Expected an integer (number of tokens for reasoning), but received '%s'.", ccs.customReasoningEffort), e);
                }
                fieldsBuilder.putDocument("thinking", thinkingConfig(budgetTokens));
            } else {
                throw new IllegalArgumentException(String.format("Unsupported Reasoning Effort for Bedrock: expected OFF, STANDARD or CUSTOM but received %s.", ccs.reasoningEffort));
            }
        }
        return fieldsBuilder.build();
    }

    /** Builds the {@code thinking} document enabling reasoning with the given token budget. */
    private static Document thinkingConfig(int budgetTokens) {
        return Document.mapBuilder()
                .putString("type", "enabled")
                .putNumber("budget_tokens", budgetTokens)
                .build();
    }

    /**
     * Converts the non-system messages into Converse messages.
     *
     * <p>A {@code "memoryFragment"} message produces no message of its own: its reasoning
     * block is held and prepended to the next {@code "assistant"} message. User messages
     * never carry reasoning and do not clear a held block.
     *
     * @throws IllegalArgumentException if a message has an unsupported role
     */
    private static List<Message> toMessages(List<LLMClient.ChatMessage> messages, AuthCtx authCtx) {
        List<Message> result = new ArrayList<>();
        ReasoningContentBlock pendingReasoning = null;
        for (LLMClient.ChatMessage message : messages) {
            switch (message.role) {
                case "system":
                    // System messages are emitted separately by toSystemMessages.
                    break;
                case "memoryFragment":
                    pendingReasoning = createReasoningBlock(message);
                    break;
                case "user":
                    result.add(buildMessage(message, ConversationRole.USER, null, authCtx));
                    break;
                case "assistant":
                    result.add(buildMessage(message, ConversationRole.ASSISTANT, pendingReasoning, authCtx));
                    pendingReasoning = null;
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported chat message role: " + message.role);
            }
        }
        return result;
    }

    /**
     * Builds a reasoning block from the {@code signature} and {@code text} entries of the
     * message's memory fragment; a missing entry is replaced by an empty string.
     *
     * @throws IllegalArgumentException if the message carries no reasoning data
     */
    private static ReasoningContentBlock createReasoningBlock(LLMClient.ChatMessage message) {
        if (message.memoryFragment == null || message.memoryFragment.llmReasoning == null) {
            throw new IllegalArgumentException("Invalid thinking message format : the thinking message is null.");
        }
        JsonObject data = message.memoryFragment.llmReasoning;
        String signature = data.has("signature") ? data.get("signature").getAsString() : "";
        String text = data.has("text") ? data.get("text").getAsString() : "";
        ReasoningTextBlock reasoningText = ReasoningTextBlock.builder().signature(signature).text(text).build();
        return ReasoningContentBlock.fromReasoningText(reasoningText);
    }

    /**
     * Assembles one Converse message. Content blocks are appended in this order: optional
     * reasoning, then text (or multipart content), then tool calls, then tool outputs.
     *
     * @param reasoning reasoning block to put first, or {@code null} for none
     */
    private static Message buildMessage(LLMClient.ChatMessage message, ConversationRole role, ReasoningContentBlock reasoning, AuthCtx authCtx) {
        List<ContentBlock> content = new ArrayList<>();
        if (reasoning != null) {
            content.add(ContentBlock.fromReasoningContent(reasoning));
        }
        if (message.isTextOnly()) {
            String text = message.getText();
            if (text != null) {
                content.add(ContentBlock.fromText(text));
            }
        } else if (message.parts != null) {
            for (LLMClient.ChatMessagePart part : message.parts) {
                content.add(adapt(part));
            }
        }
        if (message.toolCalls != null && !message.toolCalls.isEmpty()) {
            for (LLMClient.AbstractToolCall toolCall : message.toolCalls) {
                content.add(adapt(toolCall));
            }
        }
        if (message.toolOutputs != null && !message.toolOutputs.isEmpty()) {
            for (LLMClient.ToolOutput toolOutput : message.toolOutputs) {
                content.add(adapt(toolOutput, authCtx));
            }
        }
        return Message.builder().role(role).content(content).build();
    }

    /**
     * Converts a tool output into a tool-result content block.
     *
     * <p>The raw output is sent as a JSON block when it parses as a JSON object, and as a
     * text block when parsing fails with a syntax error. Parts obtained from
     * {@code LLMChatMessageUtils.getPartsFromToolOutput} are appended after it.
     */
    private static ContentBlock adapt(LLMClient.ToolOutput toolOutput, AuthCtx authCtx) {
        List<ToolResultContentBlock> blocks = new ArrayList<>();
        try {
            JsonObject toolOutputObj = (JsonObject)JSON.parse((String)toolOutput.output, JsonObject.class);
            if (toolOutputObj != null) {
                blocks.add(ToolResultContentBlock.fromJson(toDocument(toolOutputObj.toString())));
            }
        } catch (JsonSyntaxException notAJsonObject) {
            // Not a JSON object: pass the raw output through as plain text.
            blocks.add(ToolResultContentBlock.fromText(toolOutput.output));
        }

        List<LLMClient.ChatMessagePart> parts = LLMChatMessageUtils.getPartsFromToolOutput(authCtx, toolOutput, true);
        for (LLMClient.ChatMessagePart part : parts) {
            blocks.add(adaptToolResult(part));
        }

        ToolResultBlock toolResult = ToolResultBlock.builder()
                .toolUseId(toolOutput.callId)
                .content(blocks)
                .build();
        return ContentBlock.fromToolResult(toolResult);
    }

    /**
     * Converts a function tool call into a tool-use content block.
     *
     * <p>Missing or blank arguments are sent as an empty JSON object. Arguments that cannot
     * be parsed as JSON are sent as a plain string document.
     *
     * @throws NotImplementedException if the tool call is not a function tool call
     */
    private static ContentBlock adapt(LLMClient.AbstractToolCall abstractToolCall) {
        if (!(abstractToolCall instanceof LLMClient.FunctionToolCall)) {
            throw new NotImplementedException("New tool call class has not been implemented for Bedrock Converse API");
        }
        LLMClient.FunctionToolCall functionToolCall = (LLMClient.FunctionToolCall)abstractToolCall;
        String arguments = functionToolCall.function.arguments;

        Document toolUseInput;
        if (StringUtils.isBlank(arguments)) {
            // Blank input yields no JSON token, which would make toDocument fail on a null node.
            toolUseInput = toDocument("{}");
        } else {
            try {
                toolUseInput = toDocument(arguments);
            } catch (UncheckedIOException malformedJson) {
                toolUseInput = Document.fromString(arguments);
            }
        }

        ToolUseBlock toolUse = ToolUseBlock.builder()
                .toolUseId(functionToolCall.id)
                .name(functionToolCall.function.name)
                .input(toolUseInput)
                .build();
        return ContentBlock.fromToolUse(toolUse);
    }

    /**
     * Converts one multipart message part into a content block.
     *
     * @throws IllegalArgumentException for image URIs, for inline-image parts without data,
     *         and for any other part type
     */
    private static ContentBlock adapt(LLMClient.ChatMessagePart part) {
        switch (part.type) {
            case TEXT:
                return ContentBlock.fromText(part.text);
            case IMAGE_INLINE:
                if (part.inlineImage == null) {
                    throw new IllegalArgumentException("Message type was inline image, but no inline image was given.");
                }
                return ContentBlock.fromImage(toImageBlock(part.inlineImage, part.imageMimeType));
            case IMAGE_URI:
                throw new IllegalArgumentException("Image URIs not supported for Bedrock Converse API models. Use inline images instead.");
            default:
                throw new IllegalArgumentException(String.format("Unsupported assistant message part type: %s", part.type));
        }
    }

    /**
     * Converts one tool-output part into a tool-result content block (text or image).
     *
     * @throws IllegalArgumentException if the part is neither text nor image data
     */
    private static ToolResultContentBlock adaptToolResult(LLMClient.ChatMessagePart part) {
        if (part.type == LLMClient.ChatMessagePartType.TEXT) {
            return ToolResultContentBlock.fromText(part.text);
        }
        if (part.containsImageData()) {
            LLMClient.Base64Image image = part.getBase64Image();
            return ToolResultContentBlock.fromImage(toImageBlock(image.base64Data, image.mimeType));
        }
        throw new IllegalArgumentException(String.format("Unsupported multipart tool output part type: %s", part.type));
    }

    /** Decodes base64 image data and wraps it in an image block with the matching format. */
    private static ImageBlock toImageBlock(String base64Data, @Nullable String mimeType) {
        byte[] decodedBytes = Base64.getDecoder().decode(base64Data);
        return ImageBlock.builder()
                .source(ImageSource.fromBytes(SdkBytes.fromByteArray(decodedBytes)))
                .format(adaptImageFormat(mimeType))
                .build();
    }

    /**
     * Maps a MIME type to the Bedrock image format; a {@code null} MIME type maps to JPEG.
     *
     * @throws IllegalArgumentException if the MIME type is not PNG, JPEG, GIF or WebP
     */
    private static ImageFormat adaptImageFormat(@Nullable String imageMimeType) {
        if (imageMimeType == null) {
            return ImageFormat.JPEG;
        }
        switch (imageMimeType) {
            case "image/png":
                return ImageFormat.PNG;
            case "image/jpeg":
                return ImageFormat.JPEG;
            case "image/gif":
                return ImageFormat.GIF;
            case "image/webp":
                return ImageFormat.WEBP;
            default:
                throw new IllegalArgumentException(String.format("Unsupported image mime type: %s", imageMimeType));
        }
    }

    /**
     * Builds the tool configuration from the declared tools and the optional tool choice.
     *
     * <p>Each tool's parameter schema goes through {@code JSONSchemaCompatibilityEnhancer}
     * (with the NOVA provider for Amazon Nova handling, PASSTHROUGH otherwise). A blank tool
     * description is replaced by the tool name.
     *
     * @throws NotImplementedException if a tool is not a function tool
     * @throws IllegalArgumentException if the tool choice is "none"
     */
    private static ToolConfiguration adaptToolConfig(CoreCompletionSettings ccs, GenericLLMHandling handling) {
        ToolConfiguration.Builder toolConfigBuilder = ToolConfiguration.builder();
        LLMClient.ToolChoice toolChoice = ccs.toolChoice;
        JSONSchemaCompatibilityEnhancer.Provider schemaProvider = handling == GenericLLMHandling.AMAZON_NOVA
                ? JSONSchemaCompatibilityEnhancer.Provider.NOVA
                : JSONSchemaCompatibilityEnhancer.Provider.PASSTHROUGH;

        List<Tool> adaptedTools = new ArrayList<>(ccs.tools.size());
        for (LLMClient.AbstractTool abstractTool : ccs.tools) {
            if (!(abstractTool instanceof LLMClient.FunctionTool)) {
                throw new NotImplementedException("New tool class has not been implemented for Bedrock Converse API");
            }
            LLMClient.FunctionTool functionTool = (LLMClient.FunctionTool)abstractTool;
            JsonObject schema = JSONSchemaCompatibilityEnhancer.enhance(functionTool.function.getParameters(), schemaProvider);
            Document inputJson = toDocument(JSON.json((Object)schema));
            String toolName = functionTool.function.name;
            String toolDescription = StringUtils.isNotBlank(functionTool.function.description) ? functionTool.function.description : toolName;
            Tool newTool = Tool.builder()
                    .toolSpec(spec -> spec.name(toolName).description(toolDescription).inputSchema(ToolInputSchema.fromJson(inputJson)))
                    .build();
            adaptedTools.add(newTool);
        }
        if (!adaptedTools.isEmpty()) {
            toolConfigBuilder.tools(adaptedTools);
        }

        if (toolChoice != null) {
            toolConfigBuilder.toolChoice(builder -> {
                if (toolChoice instanceof LLMClient.AutoToolChoice) {
                    builder.auto(SdkBuilder::build);
                } else if (toolChoice instanceof LLMClient.RequiredToolChoice) {
                    builder.any(SdkBuilder::build);
                } else if (toolChoice instanceof LLMClient.NamedToolChoice) {
                    builder.tool(b -> b.name(((LLMClient.NamedToolChoice)toolChoice).name));
                } else if (toolChoice instanceof LLMClient.NoneToolChoice) {
                    throw new IllegalArgumentException("The Bedrock Converse API does not support tool choice: \"none\"");
                }
            });
        }
        return toolConfigBuilder.build();
    }

    /**
     * Parses a JSON string into an SDK {@link Document}.
     *
     * @throws UncheckedIOException if the string is not valid JSON
     */
    private static Document toDocument(String jsonString) {
        return JsonNodeParser.create().parse(jsonString).visit(new DocumentUnmarshaller());
    }
}

