/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.recipes.nlp.prompt;

import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.coremodel.SerializedRecipe;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.dataflow.exec.joinlike.VirtualInputBasedRecipeCreationService;
import com.dataiku.dip.datasets.ManagedDatasetsHelper;
import com.dataiku.dip.labeling.ImageViewSettings;
import com.dataiku.dip.llm.PromptDef;
import com.dataiku.dip.llm.promptstudio.PromptStudio;
import com.dataiku.dip.recipes.RecipeMeta;
import com.dataiku.dip.recipes.common.RecipeCreator;
import com.dataiku.dip.recipes.common.SISORecipeCreator;
import com.dataiku.dip.recipes.nlp.prompt.PromptRecipePayloadParams;
import com.dataiku.dip.recipes.nlp.prompt.PromptRecipeSchemaComputer;
import com.dataiku.dip.recipes.nlp.prompt.RawQueryOutputMode;
import com.dataiku.dip.recipes.nlp.prompt.RawResponseOutputMode;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.datasets.DatasetSaveService;
import com.dataiku.dip.transactions.ifaces.RWTransaction;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.WithMessages;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.util.List;
import java.util.Optional;
import org.springframework.beans.factory.annotation.Autowired;

public class PromptRecipeCreator
extends SISORecipeCreator {
    @Autowired
    private DatasetsDAO datasetsDAO;
    @Autowired
    private DatasetSaveService datasetSaveService;
    @Autowired
    private VirtualInputBasedRecipeCreationService creationService;

    public PromptRecipeCreator(AuthCtx authCtx, RecipeMeta meta) {
        super(authCtx, meta);
    }

    @Override
    protected void setOutputSchema(SerializedRecipe recipe, String payload, Dataset outputDataset) throws Exception {
        ManagedDatasetsHelper.copySchema(this.authCtx, this.getInputDataset(recipe).getSchema(), outputDataset);
        Schema schema = outputDataset.getSchema();
        for (PromptRecipeSchemaComputer.PromptRecipeColumn column : PromptRecipeSchemaComputer.PromptRecipeColumn.values()) {
            schema.withColumn(column.name, column.type);
        }
    }

    @Override
    public RecipeCreator.CreationResult create_NT(SerializedRecipe recipe, JsonObject creationData) throws Exception {
        RecipeCreator.CreationResult creationResult = super.create_NT(recipe, creationData);
        if (!recipe.getInputsForRole("images").isEmpty()) {
            SerializedDataset outputSD;
            String imageFolderId = recipe.getInputsForRole((String)"images").get((int)0).ref;
            try (Transaction t = this.transactionService.retrieveOrBeginRead();){
                outputSD = (SerializedDataset)this.datasetsDAO.getMandatory(AnyLoc.resolveSmart(recipe.projectKey, recipe.getOutputsForRole((String)"main").get((int)0).ref));
            }
            JsonElement prompt = creationData.get("initialPayload").getAsJsonObject().get("prompt");
            PromptDef promptDef = (PromptDef)JSON.parse((JsonElement)prompt, PromptDef.class);
            List<PromptStudio.PromptTemplateInput> promptTemplateInputList = promptDef.getInputs();
            Optional<PromptStudio.PromptTemplateInput> imageInput = promptTemplateInputList.stream().filter(input -> input.type == PromptStudio.PromptTemplateInputType.IMAGE).findFirst();
            if (imageInput.isPresent() && outputSD != null) {
                ImageViewSettings imageViewSettings = new ImageViewSettings();
                imageViewSettings.enabled = true;
                imageViewSettings.managedFolderSmartId = imageFolderId;
                imageViewSettings.pathColumn = imageInput.get().datasetColumnName;
                outputSD.imageViewSettings = imageViewSettings;
                try (RWTransaction t = this.transactionService.beginWriteAsLoggedInUser(this.authCtx);){
                    WithMessages<SerializedDataset> datasetWithMessages = this.datasetSaveService.save(outputSD.getProjectKey(), outputSD.getId(), outputSD, t.getUser());
                    creationResult.messages.mergeFrom(datasetWithMessages.getMessages());
                    t.commitV("Saved output dataset '%s' in project '%s'", new Object[]{outputSD.getId(), outputSD.getProjectKey()});
                }
            }
        }
        return creationResult;
    }

    @Override
    protected String makeInitialPayload(SerializedRecipe recipe, JsonObject data, Dataset outputDataset) throws Exception {
        return this.makeInitialPayload(data);
    }

    @Override
    protected String makeInitialPayload(SerializedRecipe recipe, Dataset inputDataset, Dataset outputDataset, JsonObject data) {
        return this.makeInitialPayload(data);
    }

    private String makeInitialPayload(JsonObject data) {
        JsonElement initialPayload = data.get("initialPayload");
        if (initialPayload != null) {
            return JSON.pretty((Object)initialPayload);
        }
        PromptRecipePayloadParams params = new PromptRecipePayloadParams();
        params.rawQueryOutputMode = RawQueryOutputMode.RAW_WITHOUT_FULL_IMAGES;
        params.rawResponseOutputMode = RawResponseOutputMode.RAW_WITHOUT_TRACES;
        return JSON.json((Object)params);
    }
}

