/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.recipes.nlp.prompt;

import com.dataiku.dip.agents.tools.filtering.SimpleFilter;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.exec.AbstractInitializedRunner;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.ProcessorOutputToSIP;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.RowFactory;
import com.dataiku.dip.llm.LLMAuditHelper;
import com.dataiku.dip.llm.online.CompletionRecipeLLMMeshClient;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.prompts.PromptExpander;
import com.dataiku.dip.llm.promptstudio.PromptResponse;
import com.dataiku.dip.llm.promptstudio.PromptStudio;
import com.dataiku.dip.llm.retrieval.LangchainBasedRAGServer;
import com.dataiku.dip.recipes.nlp.common.NLPLLMRecipeRunnerBase;
import com.dataiku.dip.recipes.nlp.common.NLPRecipeParallelRunInputFeedThread;
import com.dataiku.dip.recipes.nlp.prompt.PromptRecipePayloadParams;
import com.dataiku.dip.recipes.nlp.prompt.PromptRecipeSchemaComputer;
import com.dataiku.dip.recipes.nlp.prompt.RawQueryOutputMode;
import com.dataiku.dip.server.services.AuditPrivilegedClient;
import com.dataiku.dip.shaker.mrimpl.formats.RowWithFactories;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.ObjectUtils;
import com.dataiku.dip.utils.StringTransmogrifier;
import com.dataiku.dip.warnings.WarningsContext;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;

public class PromptRecipeRunner
extends NLPLLMRecipeRunnerBase {
    private PromptRecipePayloadParams desc;
    static DKULogger logger = DKULogger.getLogger((String)"dku.recipes.prompt");

    public PromptRecipeRunner(JobActivity activity) {
        super(activity);
    }

    @Override
    public void setPayload(String payload) {
        this.desc = (PromptRecipePayloadParams)JSON.parse((String)payload, PromptRecipePayloadParams.class);
        this.setLlmId(this.desc.llmId);
    }

    @Override
    public void run() throws Exception {
        logger.info((Object)"Prompt recipe API runner started");
        if (!this.desc.prompt.promptMode.canUseDataset) {
            throw new IllegalArgumentException("Cannot use a non-template prompt with an input dataset");
        }
        if (StringUtils.isBlank((CharSequence)this.desc.llmId)) {
            throw new IllegalArgumentException("No LLM was specified");
        }
        AbstractInitializedRunner.Output output = (AbstractInitializedRunner.Output)((List)this.outputs.get("main")).get(0);
        StringTransmogrifier transmogrifier = this.getOutputColTransmogrifier();
        Columns columns = new Columns();
        columns.validationStatus = output.cf.column(transmogrifier.transmogrify(PromptRecipeSchemaComputer.PromptRecipeColumn.LLM_VALIDATION_STATUS.name));
        columns.formattedLLMOutput = output.cf.column(transmogrifier.transmogrify(PromptRecipeSchemaComputer.PromptRecipeColumn.LLM_OUTPUT.name));
        columns.rawLLMQuery = output.cf.column(transmogrifier.transmogrify(PromptRecipeSchemaComputer.PromptRecipeColumn.LLM_RAW_QUERY.name));
        columns.rawLLMOutput = output.cf.column(transmogrifier.transmogrify(PromptRecipeSchemaComputer.PromptRecipeColumn.LLM_RAW_OUTPUT.name));
        columns.errorMessage = output.cf.column(transmogrifier.transmogrify(PromptRecipeSchemaComputer.PromptRecipeColumn.LLM_ERROR_MSG.name));
        try (AuditPrivilegedClient auditClient = new AuditPrivilegedClient();){
            ProcessorOutputToSIP processorOutput = new ProcessorOutputToSIP(output.out);
            try (CompletionRecipeLLMMeshClient meshClient = this.buildCompletionRecipeClient(this.desc.prompt.guardrailsPipelineSettings);){
                this.enrichedLLMRef = meshClient.getEnrichedRef();
                this.plcStream = meshClient.completeQueriesAsyncStream(this.desc.completionSettings.toFullSettings());
                PromptExpander pe = new PromptExpander(this.authCtxService.getAuthCtx(), this.recipe.getProjectKey(), this.enrichedLLMRef.promptDriven, this.desc.prompt);
                if (!this.enrichedLLMRef.promptDriven) {
                    logger.info((Object)("Setting single column " + this.desc.prompt.singleInputColumn));
                    pe.setSingleColumnForUnpromptMode(this.desc.prompt.singleInputColumn);
                }
                List<String> imageColumns = null;
                boolean shouldReplaceBase64ImagesByRefs = this.desc.prompt.imageFolderId != null && this.desc.rawQueryOutputMode == RawQueryOutputMode.RAW_WITHOUT_FULL_IMAGES;
                String imageFolderFullId = null;
                if (shouldReplaceBase64ImagesByRefs) {
                    imageColumns = this.desc.prompt.getInputs().stream().filter(i -> i.type == PromptStudio.PromptTemplateInputType.IMAGE).map(i -> i.datasetColumnName).collect(Collectors.toList());
                    imageFolderFullId = AnyLoc.resolveSmart(this.recipe.getProjectKey(), this.desc.prompt.imageFolderId).getFullName();
                }
                InputFeedThread ift = new InputFeedThread((ColumnFactory)output.cf, (RowFactory)output.rf, pe);
                ift.start();
                while (true) {
                    logger.info((Object)"Fetching next response from PLCS");
                    Optional o = this.plcStream.fetchNextResponse();
                    logger.info((Object)"Fetched response from PLCS");
                    if (!o.isPresent()) break;
                    Row outputRow = output.rf.row();
                    for (Map.Entry e : ((Map)((CompletionRecipeLLMMeshClient.SimpleCompletionResponseOrErrorWithContext)o.get()).context).entrySet()) {
                        outputRow.put((Column)output.cf.column((String)e.getKey()), (String)e.getValue());
                    }
                    PromptResponse.SingleInputPromptResponse sipr = new PromptResponse.SingleInputPromptResponse();
                    LLMClient.SimpleCompletionResponseOrError scr = ((CompletionRecipeLLMMeshClient.SimpleCompletionResponseOrErrorWithContext)o.get()).scr;
                    if (scr.ok) {
                        sipr.promptTokens = scr.promptTokens;
                        sipr.completionTokens = scr.completionTokens;
                        sipr.totalTokens = scr.totalTokens;
                        sipr.tokenCountsAreEstimated = scr.tokenCountsAreEstimated;
                        sipr.estimatedCost = scr.estimatedCost;
                        sipr.totalUsage = scr.totalUsage;
                        sipr.sources = scr.sources;
                        sipr.artifacts = scr.artifacts;
                        sipr.additionalInformation = scr.additionalInformation;
                        if (scr.toolValidationRequests != null && !scr.toolValidationRequests.isEmpty()) {
                            sipr.error = true;
                            sipr.llmError = "Tool call validation received. Validating tool calls is only supported in chat mode.";
                            this.activity.warnContext.addWarning(WarningsContext.WarningType.LLM_QUERY_ERROR, sipr.llmError, logger);
                        } else {
                            sipr.error = false;
                            sipr.rawLLMOutput = scr.text;
                            sipr.validate(this.desc.prompt.resultValidation);
                        }
                    } else {
                        sipr.error = true;
                        sipr.llmError = scr.errorMessage;
                        this.activity.warnContext.addWarning(WarningsContext.WarningType.LLM_QUERY_ERROR, scr.errorMessage, logger);
                    }
                    List<LLMClient.ImageRefExcerpt> imageRefs = null;
                    if (shouldReplaceBase64ImagesByRefs) {
                        imageRefs = this.getImageRefs(imageFolderFullId, imageColumns, (Map)((CompletionRecipeLLMMeshClient.SimpleCompletionResponseOrErrorWithContext)o.get()).context);
                    }
                    this.writeResponseInOutputRow(sipr, (CompletionRecipeLLMMeshClient.SimpleCompletionResponseOrErrorWithContext)o.get(), outputRow, columns, shouldReplaceBase64ImagesByRefs, imageRefs);
                    processorOutput.emitRow(outputRow);
                    LLMAuditHelper.emitLLMCompletionAuditFromJobIfNeeded(this.authCtx, auditClient, this.enrichedLLMRef, meshClient.getConnection(), ((CompletionRecipeLLMMeshClient.SimpleCompletionResponseOrErrorWithContext)o.get()).completionQuery, scr);
                }
                logger.info((Object)"Terminated");
                processorOutput.lastRowEmitted();
                ift.join();
                if (ift.getException() != null) {
                    throw new IOException("Input feeding failed", ift.getException());
                }
                this.handleCRU(meshClient);
            }
        }
    }

    private void writeResponseInOutputRow(PromptResponse.SingleInputPromptResponse response, CompletionRecipeLLMMeshClient.SimpleCompletionResponseOrErrorWithContext full_response, Row outputRow, Columns columns, boolean shouldReplaceBase64ImagesByRefs, @Nullable List<LLMClient.ImageRefExcerpt> imageRefs) {
        if (response.error) {
            outputRow.put(columns.errorMessage, response.llmError);
        } else {
            switch (response.validationStatus) {
                case NOT_PERFORMED: {
                    outputRow.put(columns.formattedLLMOutput, response.formattedOutput == null ? response.rawLLMOutput : response.formattedOutput);
                    break;
                }
                case VALID: {
                    outputRow.put(columns.formattedLLMOutput, response.formattedOutput);
                    outputRow.put(columns.validationStatus, response.validationStatus.toString());
                    break;
                }
                case INVALID: {
                    outputRow.put(columns.errorMessage, response.validationMessage);
                    outputRow.put(columns.validationStatus, response.validationStatus.toString());
                }
            }
        }
        switch (this.desc.rawQueryOutputMode) {
            case RAW: {
                outputRow.put(columns.rawLLMQuery, JSON.json((Object)full_response.completionQuery));
                break;
            }
            case RAW_WITHOUT_FULL_IMAGES: {
                if (shouldReplaceBase64ImagesByRefs && imageRefs != null) {
                    LLMClient.SingleCompletionQuery queryWithImageRefs = PromptRecipeRunner.replaceBase64ImagesByRefs(full_response.completionQuery, imageRefs);
                    outputRow.put(columns.rawLLMQuery, JSON.json((Object)queryWithImageRefs));
                    break;
                }
                outputRow.put(columns.rawLLMQuery, JSON.json((Object)full_response.completionQuery));
                break;
            }
        }
        switch (this.desc.rawResponseOutputMode) {
            case RAW: {
                outputRow.put(columns.rawLLMOutput, JSON.json((Object)full_response.scr));
                break;
            }
            case RAW_WITHOUT_TRACES: {
                LLMClient.SimpleCompletionResponseOrError scrCopy = (LLMClient.SimpleCompletionResponseOrError)ObjectUtils.shallowCopy((Object)full_response.scr);
                scrCopy.trace = null;
                outputRow.put(columns.rawLLMOutput, JSON.json((Object)scrCopy));
                break;
            }
            case NONE: {
                if (response.validationStatus != PromptResponse.ResponseValidationStatus.INVALID) break;
                outputRow.put(columns.rawLLMOutput, response.rawLLMOutput);
            }
        }
    }

    private List<LLMClient.ImageRefExcerpt> getImageRefs(String fullFolderId, List<String> imageColumns, Map<String, String> context) {
        ArrayList<LLMClient.ImageRefExcerpt> result = new ArrayList<LLMClient.ImageRefExcerpt>();
        for (String col : imageColumns) {
            String input = context.get(col);
            for (String imagePath : PromptExpander.convertInputToList(input)) {
                LLMClient.ImageRefExcerpt ref = new LLMClient.ImageRefExcerpt(fullFolderId, imagePath);
                result.add(ref);
            }
        }
        return result;
    }

    private static LLMClient.SingleCompletionQuery replaceBase64ImagesByRefs(LLMClient.SingleCompletionQuery completionQuery, List<LLMClient.ImageRefExcerpt> imageRefs) {
        LLMClient.SingleCompletionQuery queryWithImageRefs = new LLMClient.SingleCompletionQuery();
        queryWithImageRefs.context = completionQuery.context;
        queryWithImageRefs.messages = new ArrayList<LLMClient.ChatMessage>();
        int nextImage = 0;
        for (LLMClient.ChatMessage message : completionQuery.messages) {
            LLMClient.ChatMessage modifiedMessage = new LLMClient.ChatMessage(message);
            if (message.parts != null) {
                modifiedMessage.parts = new ArrayList<LLMClient.ChatMessagePart>();
                Iterator<LLMClient.ChatMessagePart> iterator = message.parts.iterator();
                while (iterator.hasNext()) {
                    LLMClient.ChatMessagePart part;
                    LLMClient.ChatMessagePart partToAdd = part = iterator.next();
                    if (part.type == LLMClient.ChatMessagePartType.IMAGE_INLINE) {
                        if (nextImage < imageRefs.size()) {
                            ImageRefChatMessagePart partWithImageRef = new ImageRefChatMessagePart();
                            partWithImageRef.type = LLMClient.ChatMessagePartType.IMAGE_REF;
                            partWithImageRef.imageRef = imageRefs.get(nextImage++);
                            partToAdd = partWithImageRef;
                        } else {
                            logger.error((Object)"Can't find image ref to replace base64 image. Keeping original raw query.");
                            return completionQuery;
                        }
                    }
                    modifiedMessage.parts.add(partToAdd);
                }
            }
            queryWithImageRefs.messages.add(modifiedMessage);
        }
        if (nextImage != imageRefs.size()) {
            logger.error((Object)"Some image refs were not used to replace base64 images. Keeping original raw query.");
            return completionQuery;
        }
        return queryWithImageRefs;
    }

    @Override
    public void notifyBeforeAborting() {
    }

    private static class Columns {
        Column rawLLMOutput;
        Column rawLLMQuery;
        Column formattedLLMOutput;
        Column validationStatus;
        Column errorMessage;

        private Columns() {
        }
    }

    private class InputFeedThread
    extends NLPRecipeParallelRunInputFeedThread {
        private final PromptExpander pe;

        InputFeedThread(ColumnFactory cf, RowFactory rf, PromptExpander pe) throws IOException {
            super(PromptRecipeRunner.this.authCtx, PromptRecipeRunner.this.recipe, PromptRecipeRunner.this.activity, PromptRecipeRunner.this.plcStream, cf, rf);
            this.pe = pe;
        }

        @Override
        public LLMClient.SingleCompletionQuery buildCompletionQuery(Row row) {
            LLMClient.SingleCompletionQuery filterableCompletionQuery = this.pe.expand(this.cf, row);
            if (PromptRecipeRunner.this.desc.performFiltering && PromptRecipeRunner.this.desc.filter != null) {
                if (filterableCompletionQuery.context == null) {
                    filterableCompletionQuery.context = new JsonObject();
                }
                LangchainBasedRAGServer.RagQueryFilter ragQueryFilter = new LangchainBasedRAGServer.RagQueryFilter();
                ragQueryFilter.filter = SimpleFilter.fromComplexFilter(PromptRecipeRunner.this.desc.filter, Optional.of(new RowWithFactories(this.cf, this.rf, row)));
                JsonArray jsonArray = new JsonArray();
                jsonArray.add(JSON.toJsonElement((Object)ragQueryFilter));
                filterableCompletionQuery.context.add("callerFilters", JSON.toJsonElement((Object)jsonArray));
            }
            return filterableCompletionQuery;
        }
    }

    private static class ImageRefChatMessagePart
    extends LLMClient.ChatMessagePart {
        @Nullable
        public LLMClient.ImageRefExcerpt imageRef;
    }
}

