/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.prompts;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.PromptDef;
import com.dataiku.dip.llm.online.LLMChatMessageUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.promptstudio.PromptStudio;
import com.dataiku.dip.managedfolder.ManagedFolder;
import com.dataiku.dip.managedfolder.ManagedFolderDAO;
import com.dataiku.dip.managedfolder.ManagedFolderHandler;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.variables.VariablesContext;
import com.dataiku.dip.variables.VariablesService;
import com.dataiku.dss.shadelib.org.apache.commons.io.FilenameUtils;
import com.dataiku.dss.shadelib.org.apache.commons.io.IOUtils;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import com.google.gson.Gson;
import com.google.gson.JsonSyntaxException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

public class PromptExpander {
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private ManagedFolderDAO managedFolderDAO;
    @Autowired
    private VariablesService variablesService;
    private final boolean promptDrivenLLM;
    private final PromptDef promptDef;
    private String singleColumnForUnpromptedMode;
    private final boolean expandVariables = DKUApp.getParams().getBoolParam("dku.llm.expandVariablesInPrompt.enabled", true);
    private AuthCtx authCtx;
    private final String projectKey;
    private volatile VariablesContext vc;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.prompts.expander");

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @VisibleForTesting
    private VariablesContext getVariablesContext() {
        if (this.vc != null) {
            return this.vc;
        }
        PromptExpander promptExpander = this;
        synchronized (promptExpander) {
            if (this.vc == null) {
                this.vc = StringUtils.isEmpty((String)this.projectKey) ? this.variablesService.getForGlobal() : this.variablesService.getForProject(this.projectKey);
            }
        }
        return this.vc;
    }

    public PromptExpander(EnrichedLLMStructuredRef enrichedLlmRef, PromptDef promptDef, String projectKey) {
        this.promptDrivenLLM = enrichedLlmRef.promptDriven;
        this.promptDef = promptDef;
        this.projectKey = projectKey;
        SpringUtils.getInstance().autowire((Object)this);
    }

    public PromptExpander(AuthCtx authCtx, String projectKey, boolean promptDrivenLLM, PromptDef promptDef) {
        this.authCtx = authCtx;
        this.projectKey = projectKey;
        this.promptDrivenLLM = promptDrivenLLM;
        this.promptDef = promptDef;
        SpringUtils.getInstance().autowire((Object)this);
    }

    public void setSingleColumnForUnpromptMode(String name) {
        this.singleColumnForUnpromptedMode = name;
    }

    public String getSingleColumnForUnpromptedMode() {
        return this.singleColumnForUnpromptedMode;
    }

    public LLMClient.SingleCompletionQuery expand(ColumnFactory cf, Row row) {
        LLMClient.SingleCompletionQuery recordQuery = new LLMClient.SingleCompletionQuery();
        if (!this.promptDrivenLLM) {
            logger.info((Object)"Expanding row with a non prompt-driven LLM");
            logger.info((Object)("Column: " + this.singleColumnForUnpromptedMode));
            logger.info((Object)("Column value: " + row.get(cf.column(this.singleColumnForUnpromptedMode))));
            recordQuery.messages.add(new LLMClient.ChatMessage("user", row.get(cf.column(this.singleColumnForUnpromptedMode))));
            return recordQuery;
        }
        recordQuery.messages.addAll(this.getPromptAsChatMessages(cf, row));
        logger.trace(() -> "Built record messages set:" + JSON.pretty((Object)recordQuery));
        return recordQuery;
    }

    public List<LLMClient.ChatMessage> getPromptAsChatMessages(ColumnFactory cf, Row row) {
        List<LLMClient.ChatMessage> messages = new ArrayList<LLMClient.ChatMessage>();
        switch (this.promptDef.promptMode) {
            case PROMPT_TEMPLATE_TEXT: {
                messages = this.getTextTemplateMessages(cf, row);
                break;
            }
            case PROMPT_TEMPLATE_STRUCTURED: {
                messages = this.getStructuredTemplateMessages(cf, row);
                break;
            }
            case RAW_PROMPT: 
            case CHAT: {
                throw new IllegalArgumentException("Cannot expand a prompt without inputs");
            }
        }
        return LLMChatMessageUtils.convertPartsToContentMessagesIfPossible(messages);
    }

    private List<LLMClient.ChatMessage> getTextTemplateMessages(ColumnFactory cf, Row row) {
        String userMessage = this.expandVariable(this.promptDef.textPromptTemplate);
        String systemMessage = this.expandVariable(this.promptDef.textPromptSystemTemplate);
        ArrayList<LLMClient.ChatMessage> messages = new ArrayList<LLMClient.ChatMessage>();
        for (PromptStudio.PromptTemplateInput pti : this.promptDef.getInputs()) {
            if (cf == null || row == null) continue;
            String data = this.getInputData(pti, cf, row);
            if (pti.type != PromptStudio.PromptTemplateInputType.TEXT) continue;
            userMessage = userMessage.replace("{{" + pti.name + "}}", data == null ? "" : data);
            if (!StringUtils.isNotBlank((String)systemMessage)) continue;
            systemMessage = systemMessage.replace("{{" + pti.name + "}}", data == null ? "" : data);
        }
        if (StringUtils.isNotBlank((String)systemMessage)) {
            messages.add(new LLMClient.ChatMessage("system", this.splitTemplateIntoParts(systemMessage, cf, row)));
        }
        messages.add(new LLMClient.ChatMessage("user", this.splitTemplateIntoParts(userMessage, cf, row)));
        return messages;
    }

    private List<LLMClient.ChatMessage> getStructuredTemplateMessages(ColumnFactory cf, Row row) {
        PromptStudio.PromptTemplateInput input;
        ArrayList<LLMClient.ChatMessage> messages = new ArrayList<LLMClient.ChatMessage>();
        if (StringUtils.isNotBlank((String)this.promptDef.structuredPromptPrefix)) {
            messages.add(new LLMClient.ChatMessage("system", Collections.singletonList(new LLMClient.ChatMessagePart().withText(this.promptDef.structuredPromptPrefix))));
        }
        for (PromptStudio.StructuredPromptTemplateExample example : this.promptDef.structuredPromptExamples) {
            if (example.inputs.size() != this.promptDef.getInputs().size()) {
                throw new IllegalArgumentException("Example has " + example.inputs.size() + " inputs but prompt has " + this.promptDef.getInputs().size() + " inputs");
            }
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < this.promptDef.getInputs().size(); ++i) {
                input = this.promptDef.getInputs().get(i);
                if (StringUtils.isNotEmpty((String)input.name)) {
                    sb.append(input.name);
                    sb.append(": ");
                }
                sb.append(example.inputs.get(i));
                sb.append("\n");
            }
            messages.add(new LLMClient.ChatMessage("user", Collections.singletonList(new LLMClient.ChatMessagePart().withText(sb.toString()))).withIsPartOfExample());
            messages.add(new LLMClient.ChatMessage("assistant", Collections.singletonList(new LLMClient.ChatMessagePart().withText(example.output))).withIsPartOfExample());
        }
        this.expandVariables(messages);
        ArrayList<LLMClient.ChatMessagePart> parts = new ArrayList<LLMClient.ChatMessagePart>();
        AnyLoc imageFolderLoc = this.getImageFolderLoc();
        for (int idx = 0; idx < this.promptDef.getInputs().size(); ++idx) {
            StringBuilder sb = new StringBuilder();
            input = this.promptDef.getInputs().get(idx);
            if (StringUtils.isNotEmpty((String)input.name)) {
                sb.append(input.name + ": ");
            }
            LLMClient.ChatMessagePart part = new LLMClient.ChatMessagePart();
            if (cf != null && row != null) {
                String data = this.getInputData(input, cf, row);
                if (input.type == PromptStudio.PromptTemplateInputType.IMAGE) {
                    parts.add(new LLMClient.ChatMessagePart().withText(sb.toString()));
                    if (imageFolderLoc == null) {
                        throw new IllegalArgumentException("No image folder selected");
                    }
                    List<String> imagePaths = PromptExpander.convertInputToList(data);
                    for (String imagePath : imagePaths) {
                        String inlineImage = this.getInlineImage(imageFolderLoc.getProjectKey(), imageFolderLoc.getId(), imagePath);
                        String fileExtension = FilenameUtils.getExtension((String)imagePath);
                        parts.add(new LLMClient.ChatMessagePart().withInlineImageFromExtension(inlineImage, fileExtension));
                    }
                    continue;
                }
                sb.append(data);
                parts.add(part.withText(sb.toString()));
                continue;
            }
            sb.append(String.format("{value for input #%s", idx + 1));
            if (StringUtils.isNotEmpty((String)input.name)) {
                sb.append(String.format(" (%s)", input.name));
            }
            sb.append("}");
            parts.add(part.withText(sb.toString()));
        }
        messages.add(new LLMClient.ChatMessage("user", parts));
        return messages;
    }

    private String getInputData(PromptStudio.PromptTemplateInput input, ColumnFactory cf, Row row) {
        if (input == null || cf == null || row == null) {
            return null;
        }
        String inputName = this.promptDef.getInputName(input);
        Column mc = cf.getColumn(inputName);
        if (mc == null) {
            throw new IllegalArgumentException("Input " + input.name + " not found among " + Lists.newArrayList((Iterable)cf.columns()).stream().map(Column::getName).collect(Collectors.joining(",")));
        }
        return row.get(mc);
    }

    @VisibleForTesting
    List<LLMClient.ChatMessagePart> splitTemplateIntoParts(String template, ColumnFactory cf, Row row) {
        Pattern pattern = Pattern.compile("\\{\\{(?<inputName>(?<type>image):[^}]+)\\}\\}");
        Matcher matcher = pattern.matcher(template);
        ArrayList<LLMClient.ChatMessagePart> parts = new ArrayList<LLMClient.ChatMessagePart>();
        AnyLoc imageFolderLoc = this.getImageFolderLoc();
        int pointerIndex = 0;
        while (matcher.find()) {
            int start = matcher.start();
            int end = matcher.end();
            if (start > pointerIndex) {
                parts.add(new LLMClient.ChatMessagePart().withText(template.substring(pointerIndex, start)));
            }
            String type = matcher.group("type");
            String inputName = matcher.group("inputName");
            PromptStudio.PromptTemplateInput input = this.promptDef.getInputByName(inputName);
            if (cf != null && row != null) {
                String data = this.getInputData(input, cf, row);
                if ("image".equals(type)) {
                    if (imageFolderLoc == null) {
                        throw new IllegalArgumentException("No image folder selected");
                    }
                    List<String> imagePaths = PromptExpander.convertInputToList(data);
                    for (String imagePath : imagePaths) {
                        String inlineImage = this.getInlineImage(imageFolderLoc.getProjectKey(), imageFolderLoc.getId(), imagePath);
                        String fileExtension = FilenameUtils.getExtension((String)imagePath);
                        parts.add(new LLMClient.ChatMessagePart().withInlineImageFromExtension(inlineImage, fileExtension));
                    }
                }
            } else {
                parts.add(new LLMClient.ChatMessagePart().withText("{{" + inputName + "}}"));
            }
            pointerIndex = end;
        }
        if (pointerIndex < template.length()) {
            parts.add(new LLMClient.ChatMessagePart().withText(template.substring(pointerIndex)));
        }
        return parts;
    }

    @VisibleForTesting
    AnyLoc getImageFolderLoc() {
        if (this.promptDef.imageFolderId == null) {
            return null;
        }
        return AnyLoc.resolveSmart(this.projectKey, this.promptDef.imageFolderId);
    }

    @VisibleForTesting
    String getInlineImage(String projectKey, String folderId, String imagePath) {
        Callable<InputStream> previewImage = null;
        try (Transaction t = this.transactionService.beginRead();){
            ManagedFolder mf = (ManagedFolder)this.managedFolderDAO.getMandatoryUnsafe(projectKey, folderId);
            ManagedFolderHandler handler = (ManagedFolderHandler)mf.buildHandler(this.authCtx);
            handler.getProvider();
            previewImage = () -> {
                InputStream is = handler.getInputStream(imagePath).rawStream();
                return new ManagedFolderHandler.WrappedInputStream(handler, is);
            };
        }
        catch (Exception e) {
            logger.error((Object)"Unable to create image stream", (Throwable)e);
        }
        if (previewImage != null) {
            String string;
            block18: {
                InputStream picData = (InputStream)previewImage.call();
                try {
                    byte[] sourceBytes = IOUtils.toByteArray((InputStream)picData);
                    string = Base64.getEncoder().encodeToString(sourceBytes);
                    if (picData == null) break block18;
                }
                catch (Throwable throwable) {
                    try {
                        if (picData != null) {
                            try {
                                picData.close();
                            }
                            catch (Throwable throwable2) {
                                throwable.addSuppressed(throwable2);
                            }
                        }
                        throw throwable;
                    }
                    catch (Exception e) {
                        logger.error((Object)"Unable to convert image to base64", (Throwable)e);
                    }
                }
                picData.close();
            }
            return string;
        }
        return "";
    }

    public static List<String> convertInputToList(String input) {
        if (input == null) {
            return Collections.singletonList("");
        }
        if ((input = input.trim()).startsWith("[") && input.endsWith("]")) {
            try {
                return Arrays.asList((String[])new Gson().fromJson(input, String[].class));
            }
            catch (JsonSyntaxException e) {
                return Collections.singletonList(input);
            }
        }
        return Collections.singletonList(input);
    }

    public <T> List<T> expandVariables(List<T> list) {
        if (!this.expandVariables) {
            return list;
        }
        if (this.getVariablesContext() == null) {
            logger.warn((Object)"No VariablesContext available to expand the variables in the prompt");
            return list;
        }
        return list.stream().map(this::expandVariable).collect(Collectors.toList());
    }

    private <T> T expandVariable(T value) {
        if (!this.expandVariables) {
            return value;
        }
        if (this.getVariablesContext() == null) {
            logger.warn((Object)"No VariablesContext available to expand the variables in the prompt");
            return value;
        }
        if (value == null) {
            return null;
        }
        if (value instanceof String) {
            return (T)this.getVariablesContext().expandAllowUnresolved((String)value);
        }
        if (value instanceof LLMClient.ChatMessage) {
            return (T)this.expandChatMessage((LLMClient.ChatMessage)value);
        }
        throw new IllegalArgumentException("Cannot expand variables in the prompt for this type:" + String.valueOf(value.getClass()));
    }

    private LLMClient.ChatMessage expandChatMessage(LLMClient.ChatMessage message) {
        if (message.parts != null) {
            for (LLMClient.ChatMessagePart messagePart : message.parts) {
                messagePart.text = this.getVariablesContext().expandAllowUnresolved(messagePart.text);
            }
        } else {
            String text = message.getText();
            String expandedText = this.getVariablesContext().expandAllowUnresolved(text);
            message.setTextOnly(expandedText);
        }
        return message;
    }
}

