/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.promptstudio;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.coremodel.SerializedRecipe;
import com.dataiku.dip.coremodel.VersionTag;
import com.dataiku.dip.cuspol.CustomFieldsService;
import com.dataiku.dip.cuspol.CustomPolicyHooksRegistry;
import com.dataiku.dip.dao.RecipesDAO;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.dataflow.exec.ContainerRecipeParams;
import com.dataiku.dip.exceptions.CodedException;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.PromptDef;
import com.dataiku.dip.llm.promptstudio.PromptResponse;
import com.dataiku.dip.llm.promptstudio.PromptStudio;
import com.dataiku.dip.llm.promptstudio.PromptStudioDAO;
import com.dataiku.dip.recipes.nlp.prompt.PromptRecipePayloadParams;
import com.dataiku.dip.recipes.nlp.prompt.RawQueryOutputMode;
import com.dataiku.dip.recipes.nlp.prompt.RawResponseOutputMode;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.notifications.backend.TaggableObjectChangedEvent;
import com.dataiku.dip.server.services.ITaggingService;
import com.dataiku.dip.server.services.PubSubService;
import com.dataiku.dip.server.services.TaggableObjectDiffService;
import com.dataiku.dip.server.services.TaggableObjectsService;
import com.dataiku.dip.server.services.TaggingService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.TransactionContext;
import com.dataiku.dip.transactions.ifaces.RWTransactionRef;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.util.SecretKeyGenerator;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.StringTransmogrifier;
import com.google.common.base.Preconditions;
import com.google.gson.JsonObject;
import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.IntStream;
import javax.annotation.Nullable;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class PromptStudiosCRUDService {
    @Autowired
    private PromptStudioDAO dao;
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private RecipesDAO recipesDAO;
    @Autowired
    private TaggableObjectsService taggableObjectsService;
    @Autowired
    private CustomPolicyHooksRegistry customPolicyHooksRegistry;
    @Autowired
    private PubSubService pubSub;
    @Autowired
    private TaggingService taggingService;
    @Autowired
    TaggableObjectDiffService taggableObjectDiffService;
    // NOTE(review): same bean type as `pubSub`; kept because it is package-visible and may be
    // referenced from elsewhere in the package.
    @Autowired
    PubSubService pubSubService;
    @Autowired
    private CustomFieldsService customFieldsService;
    @Autowired
    private SavedModelsDAO savedModelsDAO;
    private static final DKULogger logger = DKULogger.getLogger("dku.services.promptstudioCRUD");

    /**
     * Creates and persists a new, empty prompt studio in the given project.
     *
     * @param u          the user creating the studio; recorded in the creation and version tags
     * @param projectKey the project that will own the studio
     * @param name       the requested name; "Untitled studio" is used when blank. The name is
     *                   made unique among the project's existing studio names before being stored.
     * @return the generated id of the new studio
     */
    public String create(AuthCtx u, String projectKey, String name) throws Exception {
        StringTransmogrifier transmogrifier = this.newNameTransmogrifier(projectKey);
        String id = SecretKeyGenerator.generateSmall();
        PromptStudio studio = new PromptStudio();
        studio.projectKey = projectKey;
        studio.id = id;
        studio.creationTag = new VersionTag(u.getIdentifier());
        studio.versionTag = new VersionTag(u.getIdentifier());
        studio.name = transmogrifier.transmogrify(StringUtils.isBlank(name) ? "Untitled studio" : name);
        this.customFieldsService.enrichWithDefaultCustomFieldsForTaggableObject(studio);
        this.customPolicyHooksRegistry.onPreObjectSave(u, null, studio);
        this.dao.save(studio);
        JsonObject details = new JsonObject();
        // Report the name that was actually stored (after de-duplication), as copy() does.
        details.addProperty("objectDisplayName", studio.name);
        this.pubSub.publishAfterTransaction(new TaggableObjectChangedEvent(ITaggingService.TaggableType.PROMPT_STUDIO, projectKey, id, u, TaggableObjectChangedEvent.ActionType.PROMPT_STUDIO_CREATE).withDetails(details));
        return id;
    }

    /**
     * Persists an updated prompt studio within the current write transaction.
     *
     * <p>A change of name is published as a RENAME event. Any other save is published as an EDIT
     * event, and a metadata diff against the previously stored version is computed and published
     * when it reports a change.
     *
     * @param studio      the studio to save; {@code projectKey} and {@code id} must be non-null
     * @param summaryOnly when true, the RENAME/EDIT change event is not published
     * @return the saved studio (same instance as {@code studio})
     */
    public PromptStudio save(PromptStudio studio, boolean summaryOnly) throws IOException, CodedException {
        Preconditions.checkNotNull(studio.projectKey, "projectKey");
        Preconditions.checkNotNull(studio.id, "id");
        RWTransactionRef t = TransactionContext.retrieveWrite();
        PromptStudio preExisting = (PromptStudio) this.dao.getOrNullUnsafe(studio.projectKey, studio.id);
        this.taggableObjectsService.handleCreationVersionTagOnObjectUpdateNullAllowed(studio, preExisting);

        TaggableObjectDiffService.TaggableObjectsDiff diff = new TaggableObjectDiffService.TaggableObjectsDiff();
        JsonObject details = new JsonObject();
        details.addProperty("objectDisplayName", studio.name);
        TaggableObjectChangedEvent.ActionType action;
        boolean renamed = studio.name != null && preExisting != null && !studio.name.equals(preExisting.name);
        if (renamed) {
            action = TaggableObjectChangedEvent.ActionType.PROMPT_STUDIO_RENAME;
            details.addProperty("newName", studio.name);
            details.addProperty("oldName", preExisting.name);
        } else {
            action = TaggableObjectChangedEvent.ActionType.PROMPT_STUDIO_EDIT;
            diff = this.taggableObjectDiffService.diff(preExisting, studio, t.getUser().getIdentifier());
        }

        this.customPolicyHooksRegistry.onPreObjectSave(t.getUser(), (TaggableObjectsService.TaggableObject) this.dao.getOrNull(studio.projectKey, studio.id), studio);
        this.dao.save(studio);
        if (diff.metadataChanged()) {
            this.taggableObjectDiffService.publishAfterTransaction(diff);
        }
        if (!summaryOnly) {
            this.pubSubService.publishAfterTransaction(new TaggableObjectChangedEvent(ITaggingService.TaggableType.PROMPT_STUDIO, studio.projectKey, studio.id, t.getUser(), action).withDetails(details));
        }
        this.taggingService.onObjectSaved(studio.projectKey, studio.tags);
        return studio;
    }

    /**
     * Creates a deep copy of a prompt studio, possibly in another project.
     *
     * <p>The copy gets a new id, fresh creation/version tags for {@code user}, and a name of the
     * form "Copy of ..." made unique within the target project.
     *
     * @param user       the user performing the copy
     * @param studio     the studio to copy (not modified)
     * @param projectKey the project in which the copy is created
     * @return the persisted copy
     */
    public PromptStudio copy(AuthCtx user, PromptStudio studio, String projectKey) throws IOException, CodedException {
        String id = SecretKeyGenerator.generateSmall();
        PromptStudio copy = (PromptStudio) JSON.deepCopy(studio);
        copy.projectKey = projectKey;
        copy.id = id;
        // Two distinct instances (as in create()) so that updating one tag never alters the other.
        copy.creationTag = new VersionTag(user.getIdentifier());
        copy.versionTag = new VersionTag(user.getIdentifier());
        copy.name = this.newNameTransmogrifier(projectKey).transmogrify("Copy of " + studio.name);
        this.customPolicyHooksRegistry.onPreObjectSave(user, null, copy);
        this.dao.save(copy);
        JsonObject details = new JsonObject();
        details.addProperty("objectDisplayName", copy.name);
        details.addProperty("copy", true);
        details.addProperty("originalObjectDisplayName", studio.name);
        details.addProperty("originalObjectId", studio.id);
        this.pubSub.publishAfterTransaction(new TaggableObjectChangedEvent(ITaggingService.TaggableType.PROMPT_STUDIO, projectKey, id, user, TaggableObjectChangedEvent.ActionType.PROMPT_STUDIO_CREATE).withDetails(details));
        return copy;
    }

    /**
     * Reads the run history of a prompt.
     *
     * @return the stored history, or a new empty history when no history file exists yet
     */
    public PromptStudio.PromptStudioPromptHistory getPromptHistory(String projectKey, String promptStudioId, String promptId) throws IOException {
        assert (projectKey != null);
        assert (promptStudioId != null);
        assert (promptId != null);
        File f = historyFile(projectKey, promptStudioId, promptId);
        if (!f.isFile()) {
            return new PromptStudio.PromptStudioPromptHistory();
        }
        return (PromptStudio.PromptStudioPromptHistory) JSON.parseFile(f, PromptStudio.PromptStudioPromptHistory.class);
    }

    /**
     * Appends a new run entry to a prompt's history and rewrites the history file.
     *
     * @param time the run timestamp stored as {@code runOn}
     * @return the generated run id of the new entry
     */
    public String addToPromptHistory(AuthCtx authCtx, long time, String projectKey, String promptStudioId, PromptStudio.PromptStudioPrompt prompt) throws IOException {
        PromptStudio.PromptStudioPromptHistory psph = this.getPromptHistory(projectKey, promptStudioId, prompt.id);
        PromptStudio.PromptStudioPromptHistoryEntry entry = new PromptStudio.PromptStudioPromptHistoryEntry();
        entry.promptStudioPrompt = prompt;
        entry.runBy = authCtx.getIdentifier();
        entry.runOn = time;
        entry.runId = SecretKeyGenerator.generate(16);
        psph.entries.add(entry);
        JSON.prettyToFile(psph, historyFile(projectKey, promptStudioId, prompt.id));
        return entry.runId;
    }

    /**
     * Records the outcome of a run: stores the response stats on the matching history entry (if
     * found; otherwise a warning is logged) and always writes the full response to the run's
     * response file.
     */
    public void updatePromptHistory(AuthCtx authCtx, String projectKey, String promptStudioId, String promptId, String promptRunId, PromptResponse response) throws IOException {
        PromptStudio.PromptStudioPromptHistory psph = this.getPromptHistory(projectKey, promptStudioId, promptId);
        PromptStudio.PromptStudioPromptHistoryEntry entry = psph.entries.stream().filter(e -> promptRunId.equals(e.runId)).findFirst().orElse(null);
        if (entry != null) {
            entry.responseStats = response.stats;
            JSON.prettyToFile(psph, historyFile(projectKey, promptStudioId, promptId));
        } else {
            logger.warnV("Did not find prompt history entry projectKey=%s promptStudioId=%s promptId=%s promptRunId=%s", new Object[]{projectKey, promptStudioId, promptId, promptRunId});
        }
        JSON.prettyToFile(response, historyResponseFile(projectKey, promptStudioId, promptId, promptRunId));
    }

    /**
     * Reads the last response of a prompt.
     *
     * @return the stored response, or a new empty response when none is stored
     */
    public PromptResponse getLastResponse(String projectKey, String promptStudioId, String promptId) throws IOException {
        assert (projectKey != null);
        assert (promptStudioId != null);
        assert (promptId != null);
        File f = lastResponseFile(projectKey, promptStudioId, promptId);
        if (!f.isFile()) {
            return new PromptResponse();
        }
        return (PromptResponse) JSON.parseFile(f, PromptResponse.class);
    }

    /**
     * Reads the response stored for a specific run of a prompt.
     *
     * @return the stored response, or a new empty response when none is stored for {@code runId}
     */
    public PromptResponse getHistoryResponse(String projectKey, String promptStudioId, String promptId, String runId) throws IOException {
        assert (projectKey != null);
        assert (promptStudioId != null);
        assert (promptId != null);
        File f = historyResponseFile(projectKey, promptStudioId, promptId, runId);
        if (!f.isFile()) {
            return new PromptResponse();
        }
        return (PromptResponse) JSON.parseFile(f, PromptResponse.class);
    }

    /**
     * Restores a prompt to the definition it had at a past run, saves the studio, and makes that
     * run's response the prompt's last response.
     *
     * @throws IllegalArgumentException if the run or the prompt cannot be found
     */
    public PromptStudio revertToPastRun(String projectKey, PromptStudio promptStudio, String promptId, String promptRunId) throws IOException, CodedException {
        List<PromptStudio.PromptStudioPromptHistoryEntry> historyEntries = this.getPromptHistory(projectKey, promptStudio.id, promptId).entries;
        PromptStudio.PromptStudioPromptHistoryEntry entry = historyEntries.stream().filter(e -> promptRunId.equals(e.runId)).findFirst().orElse(null);
        if (entry == null) {
            throw new IllegalArgumentException(String.format("Could not find a run with the id %s for the prompt %s", promptRunId, promptId));
        }
        int promptIdx = this.findPromptIdx(promptStudio, promptId);
        promptStudio.prompts.set(promptIdx, entry.promptStudioPrompt);
        this.save(promptStudio, false);
        PromptResponse response = this.getHistoryResponse(projectKey, promptStudio.id, promptId, promptRunId);
        this.saveLastResponse(projectKey, promptStudio.id, promptId, response);
        return promptStudio;
    }

    /**
     * Returns the index of the prompt with the given id in {@code promptStudio.prompts}.
     *
     * @throws IllegalArgumentException if no prompt has that id
     */
    public int findPromptIdx(PromptStudio promptStudio, String promptId) {
        return IntStream.range(0, promptStudio.prompts.size())
                .filter(idx -> promptId.equals(promptStudio.prompts.get(idx).id))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException(String.format("Could not find the prompt %s in the prompt studio %s (%s)", promptId, promptStudio.name, promptStudio.id)));
    }

    /** Overwrites the last-response file of a prompt with {@code response}. */
    public void saveLastResponse(String projectKey, String promptStudioId, String promptId, PromptResponse response) throws IOException {
        assert (projectKey != null);
        assert (promptStudioId != null);
        assert (promptId != null);
        JSON.prettyToFile(response, lastResponseFile(projectKey, promptStudioId, promptId));
    }

    /**
     * Deletes the last-response file of a prompt.
     *
     * @throws RuntimeException if the file does not exist or could not be deleted
     */
    public void deleteLastResponse(String projectKey, String promptStudioId, String promptId) {
        assert (projectKey != null);
        assert (promptStudioId != null);
        assert (promptId != null);
        File f = lastResponseFile(projectKey, promptStudioId, promptId);
        if (!f.exists()) {
            throw new RuntimeException("Could not find file containing the last response for this prompt: " + f.getAbsolutePath());
        }
        if (!f.delete()) {
            // Arguments ordered to match the message: prompt first, then studio, then project.
            throw new RuntimeException(String.format("Unable to delete last response file of prompt %s in prompt studio %s (project %s): %s", promptId, promptStudioId, projectKey, f.getAbsolutePath()));
        }
    }

    /**
     * Copies, for every CHAT-mode prompt of the source studio that also exists (same id) in the
     * copied studio, the prompt history and the last response.
     *
     * <p>Data is read from the source studio's project and written to the copied studio's own
     * project, so copies made into another project (see {@link #copy}) keep their chat state.
     *
     * <p>NOTE(review): per-run response files are not copied, so {@link #getHistoryResponse} on the
     * copy returns an empty response for these runs.
     */
    public void copyChatResponses(AuthCtx authCtx, PromptStudio sourcePromptStudio, PromptStudio copiedPromptStudio) throws IOException {
        String sourceProjectKey = sourcePromptStudio.projectKey;
        String targetProjectKey = copiedPromptStudio.projectKey;
        for (PromptStudio.PromptStudioPrompt psp : sourcePromptStudio.prompts) {
            if (psp.prompt.promptMode != PromptStudio.PromptMode.CHAT) {
                continue;
            }
            String promptId = psp.id;
            boolean presentInCopy = copiedPromptStudio.prompts.stream().anyMatch(prompt -> prompt.id.equals(promptId));
            if (!presentInCopy) {
                continue;
            }
            PromptStudio.PromptStudioPromptHistory history = this.getPromptHistory(sourceProjectKey, sourcePromptStudio.id, promptId);
            JSON.prettyToFile(history, historyFile(targetProjectKey, copiedPromptStudio.id, promptId));
            PromptResponse response = this.getLastResponse(sourceProjectKey, sourcePromptStudio.id, promptId);
            this.saveLastResponse(targetProjectKey, copiedPromptStudio.id, promptId, response);
        }
    }

    /**
     * Deletes a prompt studio and publishes a DELETE event after the transaction.
     *
     * @throws IllegalArgumentException if no studio with this id exists in the project
     */
    public void delete(AuthCtx liu, String projectKey, String id) throws Exception {
        PromptStudio promptStudio = (PromptStudio) this.dao.getOrNull(projectKey, id);
        if (promptStudio == null) {
            // Previously failed with a NullPointerException on promptStudio.name.
            throw new IllegalArgumentException(String.format("Could not find the prompt studio %s in project %s", id, projectKey));
        }
        this.customPolicyHooksRegistry.onPreObjectDelete(liu, promptStudio);
        JsonObject details = new JsonObject();
        details.addProperty("objectDisplayName", promptStudio.name);
        this.dao.delete(projectKey, id);
        this.pubSub.publishAfterTransaction(new TaggableObjectChangedEvent(ITaggingService.TaggableType.PROMPT_STUDIO, projectKey, id, liu, TaggableObjectChangedEvent.ActionType.PROMPT_STUDIO_DELETE).withDetails(details));
    }

    /**
     * Imports the prompt of a prompt recipe into a prompt studio.
     *
     * <p>The target studio is {@code ps2.promptStudioId}, or a newly created studio named
     * {@code ps2.newPromptStudioName} when that id is blank. When
     * {@code ps2.originPromptStudioPromptId} is set, that prompt is replaced in place (keeping its
     * id, row count, inline queries, description, tags and starred flag). Otherwise the prompt is
     * appended with a new id.
     *
     * @throws IllegalArgumentException if the recipe has no "main" input, or the origin prompt is
     *                                  not found in the studio
     */
    public EditFromRecipeResult editFromRecipe(AuthCtx u, String projectKey, PromptCreationSettings ps2, PromptRecipePayloadParams payloadParams, SerializedRecipe recipe) throws Exception {
        // Validate before creating or mutating any studio.
        List<SerializedRecipe.RecipeInput> mainInputs = recipe.getInputsForRole("main");
        if (mainInputs.isEmpty()) {
            throw new IllegalArgumentException("Source recipe does not have an input dataset");
        }
        List<SerializedRecipe.RecipeInput> imageInputs = recipe.getInputsForRole("images");

        PromptStudio.PromptStudioPrompt newPromptStudioPrompt = new PromptStudio.PromptStudioPrompt();
        newPromptStudioPrompt.prompt = PromptDef.copy(payloadParams.prompt);
        newPromptStudioPrompt.prompt.promptTemplateQueriesSource = PromptStudio.PromptTemplateQueriesSource.DATASET;
        if (!imageInputs.isEmpty()) {
            newPromptStudioPrompt.prompt.imageFolderId = imageInputs.get(0).ref;
        }
        if (StringUtils.isBlank(ps2.promptStudioId)) {
            ps2.promptStudioId = this.create(u, projectKey, ps2.newPromptStudioName);
        }
        PromptStudio promptStudio = (PromptStudio) this.dao.getMandatoryUnsafe(projectKey, ps2.promptStudioId);
        if (StringUtils.isNotBlank(ps2.originPromptStudioPromptId)) {
            int promptIdx = this.findPromptIdx(promptStudio, ps2.originPromptStudioPromptId);
            PromptStudio.PromptStudioPrompt oldPrompt = promptStudio.prompts.get(promptIdx);
            promptStudio.prompts.set(promptIdx, newPromptStudioPrompt);
            newPromptStudioPrompt.id = oldPrompt.id;
            newPromptStudioPrompt.nbRows = oldPrompt.nbRows;
            newPromptStudioPrompt.inlinePromptTemplateQueries = oldPrompt.inlinePromptTemplateQueries;
            newPromptStudioPrompt.description = oldPrompt.description;
            newPromptStudioPrompt.tags = oldPrompt.tags;
            newPromptStudioPrompt.starred = oldPrompt.starred;
        } else {
            promptStudio.prompts.add(newPromptStudioPrompt);
            newPromptStudioPrompt.id = SecretKeyGenerator.generate(10);
        }
        newPromptStudioPrompt.llmId = payloadParams.llmId;
        newPromptStudioPrompt.llmSettings = payloadParams.completionSettings;
        newPromptStudioPrompt.dataset = mainInputs.get(0).ref;
        if (recipe.params instanceof ContainerRecipeParams) {
            newPromptStudioPrompt.containerSelection = ((ContainerRecipeParams) recipe.params).containerSelection;
        }
        this.save(promptStudio, false);
        return new EditFromRecipeResult(promptStudio.id, newPromptStudioPrompt.id);
    }

    /**
     * Adds to a prompt studio a RAW_PROMPT prompt that targets the retrieval-augmented LLM of the
     * given saved model. A new studio is created when {@code promptCreationSettings.promptStudioId}
     * is blank.
     */
    public EditFromRecipeResult testFromKb(AuthCtx u, String projectKey, String smRef, PromptCreationSettings promptCreationSettings) throws Exception {
        AnyLoc loc = AnyLoc.resolveSmart(projectKey, smRef);
        SavedModel sm = (SavedModel) this.savedModelsDAO.getMandatoryUnsafe(loc);
        String savedModelSmartId = new AnyLoc(sm.projectKey, sm.id).getSmartName(projectKey);
        LLMStructuredRef llmRef = LLMStructuredRef.forRetrievalAugmentedLLM(savedModelSmartId);
        PromptStudio.PromptStudioPrompt newPromptStudioPrompt = new PromptStudio.PromptStudioPrompt();
        newPromptStudioPrompt.id = SecretKeyGenerator.generate(10);
        newPromptStudioPrompt.llmId = llmRef.id;
        newPromptStudioPrompt.prompt = new PromptDef();
        newPromptStudioPrompt.prompt.promptMode = PromptStudio.PromptMode.RAW_PROMPT;
        newPromptStudioPrompt.prompt.rawPromptText = "Ask a question to your Retrieval-Augmented LLM.";
        if (StringUtils.isBlank(promptCreationSettings.promptStudioId)) {
            promptCreationSettings.promptStudioId = this.create(u, projectKey, promptCreationSettings.newPromptStudioName);
        }
        PromptStudio promptStudio = (PromptStudio) this.dao.getMandatoryUnsafe(projectKey, promptCreationSettings.promptStudioId);
        promptStudio.prompts.add(newPromptStudioPrompt);
        this.save(promptStudio, false);
        return new EditFromRecipeResult(promptStudio.id, newPromptStudioPrompt.id);
    }

    /**
     * Builds a prompt-recipe payload from the prompt at {@code promptIdx} in the studio. The prompt
     * definition is copied and switched to dataset-sourced template queries.
     */
    public PromptRecipePayloadParams getRecipePayloadFromStudioPrompt(int promptIdx, PromptStudio promptStudio) {
        PromptStudio.PromptStudioPrompt prompt = promptStudio.prompts.get(promptIdx);
        PromptRecipePayloadParams payload = new PromptRecipePayloadParams();
        payload.llmId = prompt.llmId;
        payload.completionSettings = prompt.llmSettings;
        payload.associatedPromptStudioId = promptStudio.id;
        payload.associatedPromptStudioPromptId = prompt.id;
        payload.rawQueryOutputMode = RawQueryOutputMode.RAW_WITHOUT_FULL_IMAGES;
        payload.rawResponseOutputMode = RawResponseOutputMode.RAW_WITHOUT_TRACES;
        payload.prompt = PromptDef.copy(prompt.prompt);
        payload.prompt.promptTemplateQueriesSource = PromptStudio.PromptTemplateQueriesSource.DATASET;
        return payload;
    }

    /**
     * Overwrites an existing prompt recipe with the given studio prompt.
     *
     * <p>The recipe must read the same main dataset (and, if it has one, the same image folder) as
     * the prompt. Its "model" input is aligned with the saved model referenced by the prompt's LLM.
     *
     * @throws IllegalArgumentException if the recipe's inputs are incompatible with the prompt
     */
    public void editRecipe(String projectKey, PromptStudio promptStudio, String promptId, String recipeName) throws Exception {
        int promptIdx = this.findPromptIdx(promptStudio, promptId);
        PromptStudio.PromptStudioPrompt prompt = promptStudio.prompts.get(promptIdx);
        SerializedRecipe recipe = (SerializedRecipe) this.recipesDAO.getMandatoryUnsafe(projectKey, recipeName);
        List<SerializedRecipe.RecipeInput> inputs = recipe.getInputsForRole("main");
        List<SerializedRecipe.RecipeInput> modelInput = recipe.getInputsForRole("model");
        List<SerializedRecipe.RecipeInput> imageInputs = recipe.getInputsForRole("images");

        if (inputs.isEmpty()) {
            throw new IllegalArgumentException("Target recipe does not have an input dataset");
        }
        if (!Objects.equals(inputs.get(0).ref, prompt.dataset)) {
            throw new IllegalArgumentException("Target recipe does not use the same dataset as the prompt, cannot overwrite");
        }

        String imageFolderId = prompt.prompt.imageFolderId;
        if (!imageInputs.isEmpty()) {
            if (!Objects.equals(imageInputs.get(0).ref, imageFolderId)) {
                throw new IllegalArgumentException("Target recipe does not use the same image folder as the prompt, cannot overwrite");
            }
        } else if (StringUtils.isNotEmpty(imageFolderId)) {
            recipe.addInput("images", imageFolderId);
        }

        LLMStructuredRef ref = LLMStructuredRef.decodeId(prompt.llmId);
        if (!modelInput.isEmpty()) {
            String modelId = modelInput.get(0).ref;
            if (ref.savedModelSmartId == null) {
                recipe.clearInputsForRole("model");
            } else if (!Objects.equals(ref.savedModelSmartId, modelId)) {
                // Value comparison: `!=` compared String references and could replace an
                // identical model input.
                recipe.clearInputsForRole("model");
                recipe.addInput("model", ref.savedModelSmartId);
            }
        } else if (ref.savedModelSmartId != null) {
            recipe.addInput("model", ref.savedModelSmartId);
        }

        PromptRecipePayloadParams payload = this.getRecipePayloadFromStudioPrompt(promptIdx, promptStudio);
        // NOTE(review): the input changes above are made on the object returned by
        // getMandatoryUnsafe, but null is passed as the recipe argument here. Confirm that
        // RecipesDAO.save with a null recipe still persists those input changes.
        this.recipesDAO.save(projectKey, recipeName, null, JSON.pretty(payload));
    }

    /**
     * Inserts {@code newMessage} into a conversation map keyed by message id.
     *
     * <p>For an empty conversation, a fresh root message is added first and becomes the new
     * message's parent. The new message's {@code version} is set to the number of messages that
     * already share its parent.
     */
    public void addNewChatMessage(Map<String, PromptStudio.ConversationMessage> messages, PromptStudio.ConversationMessage newMessage) {
        if (messages.isEmpty()) {
            PromptStudio.ConversationMessage parentMessage = new PromptStudio.ConversationMessage();
            messages.put(parentMessage.id, parentMessage);
            newMessage.parentId = parentMessage.id;
        }
        newMessage.version = (int) messages.values().stream().filter(message -> Objects.equals(message.parentId, newMessage.parentId)).count();
        messages.put(newMessage.id, newMessage);
    }

    /** Returns a transmogrifier pre-seeded with the names of the project's existing studios. */
    private StringTransmogrifier newNameTransmogrifier(String projectKey) throws IOException, CodedException {
        StringTransmogrifier transmogrifier = new StringTransmogrifier(" ");
        for (PromptStudio head : this.dao.list(projectKey)) {
            transmogrifier.addAlreadyTransmogrifiedAcceptDupes(head.name);
        }
        return transmogrifier;
    }

    /** File for segments prompt-studios/{project}/{studio}/prompts/{prompt}/history.json. */
    private static File historyFile(String projectKey, String promptStudioId, String promptId) {
        return DKUApp.getFile(new String[]{"prompt-studios", projectKey, promptStudioId, "prompts", promptId, "history.json"});
    }

    /** File for segments prompt-studios/{project}/{studio}/prompts/{prompt}/last-response.json. */
    private static File lastResponseFile(String projectKey, String promptStudioId, String promptId) {
        return DKUApp.getFile(new String[]{"prompt-studios", projectKey, promptStudioId, "prompts", promptId, "last-response.json"});
    }

    /** File for segments prompt-studios/{project}/{studio}/prompts/{prompt}/history-responses/response-{run}.json. */
    private static File historyResponseFile(String projectKey, String promptStudioId, String promptId, String runId) {
        return DKUApp.getFile(new String[]{"prompt-studios", projectKey, promptStudioId, "prompts", promptId, "history-responses", "response-" + runId + ".json"});
    }

    public static class PromptCreationSettings {
        @Nullable
        String promptStudioId;
        @Nullable
        String newPromptStudioName;
        @Nullable
        String originPromptStudioPromptId;
    }

    public static class EditFromRecipeResult {
        public String promptStudioId;
        public String promptStudioPromptId;

        EditFromRecipeResult(String promptStudioId, String promptStudioPromptId) {
            this.promptStudioId = promptStudioId;
            this.promptStudioPromptId = promptStudioPromptId;
        }
    }
}

