/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.promptstudio;

import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.memimpl.MemColumn;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.dip.datasets.DatasetSelectionToMemTable;
import com.dataiku.dip.datasets.SamplingParam;
import com.dataiku.dip.datasets.SingleThreadPusherToMemTable;
import com.dataiku.dip.exceptions.DSSInternalErrorException;
import com.dataiku.dip.futures.FutureProgress;
import com.dataiku.dip.futures.FutureProgressState;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMAuditHelper;
import com.dataiku.dip.llm.LLMRefEnricherService;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.PromptDef;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.governance.GuardrailsPipelineUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMClientFactory;
import com.dataiku.dip.llm.online.LLMMeshClient;
import com.dataiku.dip.llm.online.LLMMeshClientFactory;
import com.dataiku.dip.llm.prompts.PromptExpander;
import com.dataiku.dip.llm.promptstudio.PromptResponse;
import com.dataiku.dip.llm.promptstudio.PromptResponsePreview;
import com.dataiku.dip.llm.promptstudio.PromptStudio;
import com.dataiku.dip.llm.promptstudio.PromptStudiosCRUDService;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.resourceusage.ComputeResourceUsageReportingService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.audit.AuditTrailService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.streaming.endpoints.httpsse.MiniSSEEmitter;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ExceptionUtils;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.org.apache.commons.lang3.StringUtils;
import jakarta.servlet.http.HttpServletResponse;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.json.JSONObject;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Runs Prompt Studio prompts against an LLM through the LLM mesh.
 *
 * <p>Two entry points are exposed: {@link #respond()} for template / raw prompts (one completion
 * per input record) and {@link #streamChatResponse} for a single chat turn pushed to an SSE
 * emitter. The {@code @Autowired} service fields are injected by {@code SpringUtils} in the
 * constructor.
 *
 * <p>NOTE(review): this source is CFR decompiler output. {@code streamCompletion} could not be
 * decompiled and is a stub that always throws {@link IllegalStateException}.
 */
public class PromptExecutionEngine {
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private DatasetsDAO datasetsDAO;
    @Autowired
    private AuditTrailService auditTrailService;
    @Autowired
    private ComputeResourceUsageReportingService cruReportingService;
    @Autowired
    private PromptStudiosCRUDService promptStudiosCrudService;
    @Autowired
    private LLMRefEnricherService llmRefEnricherService;
    private final AuthCtx authCtx;
    private final PromptStudio promptStudio;
    private final PromptStudio.PromptStudioPrompt promptStudioPrompt;
    private final PromptDef promptDef;
    /** Run identifier stamped on every produced {@link PromptResponse}; null unless set via the setter. */
    private String promptRunId;
    /** When non-null, returned by {@link #getRecordsForPromptTemplate()} instead of dataset/inline records. */
    private MemTable forcedRecordsForPromptTemplate;
    /** Synthetic column name holding inline inputs when the LLM is not prompt-driven. */
    private static final String DKU_SINGLE_INLINE_INPUT = "__dku_single_inline_input__";
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.promptstudio.engine");

    /**
     * Creates an engine bound to one prompt of a prompt studio and autowires the service fields.
     *
     * @param authCtx            authentication context passed to LLM clients and dataset reads;
     *                           its identifier is recorded as {@code runBy} on responses
     * @param promptStudio       the studio owning the prompt (provides the project key)
     * @param promptStudioPrompt the prompt to execute; its {@code prompt} definition is cached
     */
    public PromptExecutionEngine(AuthCtx authCtx, PromptStudio promptStudio, PromptStudio.PromptStudioPrompt promptStudioPrompt) {
        this.authCtx = authCtx;
        this.promptStudioPrompt = promptStudioPrompt;
        this.promptDef = promptStudioPrompt.prompt;
        this.promptStudio = promptStudio;
        SpringUtils.getInstance().autowire((Object)this);
        logger.info((Object)"PEE initialized");
    }

    /** Sets the run identifier copied into every response built by this engine. */
    public void setPromptRunId(String promptRunId) {
        this.promptRunId = promptRunId;
    }

    /** Forces the records used for template prompts, bypassing dataset / inline record loading. */
    public void setForcedRecordsForPromptTemplate(MemTable mt) {
        this.forcedRecordsForPromptTemplate = mt;
    }

    /**
     * Runs the prompt once per input and collects the LLM responses together with run statistics.
     *
     * <p>Template modes build one completion query per record from
     * {@link #getRecordsForPromptTemplate()}. {@code RAW_PROMPT} builds a single query from the raw
     * prompt text. {@code CHAT} mode is rejected; chat goes through {@link #streamChatResponse}.
     *
     * @return the aggregated response, including per-record results and validation / cost stats
     * @throws IllegalArgumentException  if a raw prompt is run against a non-prompt-driven LLM
     * @throws DSSInternalErrorException if the prompt is in CHAT mode
     * @throws Exception                 propagated from record loading or the LLM client
     */
    public PromptResponse respond() throws Exception {
        EnrichedLLMStructuredRef enrichedLLMRef = this.llmRefEnricherService.getEnrichedLLMRef(this.promptStudioPrompt.llmId, this.authCtx, this.promptStudio.getProjectKey());
        logger.info((Object)"Starting to respond to prompt");
        ArrayList<LLMClient.SingleCompletionQuery> in = new ArrayList<LLMClient.SingleCompletionQuery>();
        MemTable recordsFromPromptTemplate = null;
        PromptExpander pe = new PromptExpander(this.authCtx, this.promptStudio.projectKey, enrichedLLMRef.promptDriven, this.promptDef);
        switch (this.promptDef.promptMode) {
            case PROMPT_TEMPLATE_TEXT: 
            case PROMPT_TEMPLATE_STRUCTURED: {
                if (!enrichedLLMRef.promptDriven) {
                    // Non-prompt-driven LLM: tell the expander which single column carries the input
                    // (the synthetic inline column, or the configured dataset column).
                    if (this.promptDef.promptTemplateQueriesSource == PromptStudio.PromptTemplateQueriesSource.INLINE) {
                        pe.setSingleColumnForUnpromptMode(DKU_SINGLE_INLINE_INPUT);
                    } else {
                        pe.setSingleColumnForUnpromptMode(this.promptDef.singleInputColumn);
                    }
                }
                recordsFromPromptTemplate = this.getRecordsForPromptTemplate();
                logger.info((Object)("Columns in records: " + recordsFromPromptTemplate.columns.values().stream().map(c2 -> c2.getName()).collect(Collectors.joining(","))));
                for (int recordIdx = 0; recordIdx < recordsFromPromptTemplate.nrows(); ++recordIdx) {
                    in.add(pe.expand(recordsFromPromptTemplate, recordsFromPromptTemplate.rows.get(recordIdx)));
                }
                break;
            }
            case RAW_PROMPT: {
                if (!enrichedLLMRef.promptDriven) {
                    throw new IllegalArgumentException("Cannot run a single-shot prompt with a non-promptable LLM");
                }
                LLMClient.SingleCompletionQuery recordQuery = new LLMClient.SingleCompletionQuery();
                recordQuery.messages.add(new LLMClient.ChatMessage("user", this.promptDef.rawPromptText));
                pe.expandVariables(recordQuery.messages);
                in.add(recordQuery);
                break;
            }
            case CHAT: {
                throw new DSSInternalErrorException("Prompt mode CHAT is not supported");
            }
        }
        AnyLoc usedDataset = null;
        if (this.promptDef.promptMode.canUseDataset && this.promptDef.promptTemplateQueriesSource == PromptStudio.PromptTemplateQueriesSource.DATASET) {
            usedDataset = AnyLoc.resolveSmartNullSafe(this.promptStudio.projectKey, this.promptStudioPrompt.dataset);
        }
        GuardrailsPipelineSettings connectionGuardrailsPipelineSettings = GuardrailsPipelineUtils.getConnectionAndLLMLevelSettings(this.authCtx, this.promptStudio.projectKey, enrichedLLMRef);
        GuardrailsPipelineSettings guardrailsPipelineSettings = GuardrailsPipelineUtils.mergeEnforcementSettings(connectionGuardrailsPipelineSettings, this.promptDef.guardrailsPipelineSettings);
        try (LLMMeshClient llmMeshClient = LLMMeshClientFactory.get(this.authCtx, this.promptStudio.projectKey, enrichedLLMRef, guardrailsPipelineSettings, usedDataset, in.size());){
            List<LLMClient.SimpleCompletionResponseOrError> responses = null;
            try (FutureProgress.AutocloseableFutureProgressState _ignored = FutureProgress.pushAutoCloseableState((String)"Querying LLM", (double)in.size(), (FutureProgressState.StateUnit)FutureProgressState.StateUnit.RECORDS);){
                responses = llmMeshClient.completeQueries(in, this.promptStudioPrompt.llmSettings.toFullSettings());
            }
            ComputeResourceUsage cru = llmMeshClient.getTotalCRU(ComputeResourceUsage.LLMUsageType.COMPLETION);
            if (cru != null) {
                this.cruReportingService.reportComplete(cru);
            }
            PromptResponse promptResponse = new PromptResponse();
            promptResponse.promptId = this.promptStudioPrompt.id;
            promptResponse.runBy = this.authCtx.getIdentifier();
            promptResponse.runOn = System.currentTimeMillis();
            promptResponse.runId = this.promptRunId;
            promptResponse.querySource = this.promptDef.promptTemplateQueriesSource;
            switch (this.promptDef.promptMode) {
                case PROMPT_TEMPLATE_TEXT: 
                case PROMPT_TEMPLATE_STRUCTURED: {
                    assert (recordsFromPromptTemplate != null);
                    promptResponse.mainPromptTemplateInputs = this.getMainPromptTemplateInputsFromPromptDef(enrichedLLMRef.promptDriven);
                    for (int recordIdx = 0; recordIdx < recordsFromPromptTemplate.nrows(); ++recordIdx) {
                        PromptResponse.SingleInputPromptResponse sipr = new PromptResponse.SingleInputPromptResponse();
                        sipr.mainInputs = this.getMainInputsFromRecords(recordsFromPromptTemplate, recordIdx, enrichedLLMRef.promptDriven, pe.getSingleColumnForUnpromptedMode());
                        this.fillSingleInputPromptResponse(sipr, responses.get(recordIdx));
                        promptResponse.responses.add(sipr);
                    }
                    break;
                }
                case RAW_PROMPT: {
                    // promptDriven was already enforced when the raw query was built above.
                    PromptResponse.SingleInputPromptResponse sipr = new PromptResponse.SingleInputPromptResponse();
                    this.fillSingleInputPromptResponse(sipr, responses.get(0));
                    promptResponse.responses.add(sipr);
                    break;
                }
            }
            for (int recordIdx = 0; recordIdx < promptResponse.responses.size(); ++recordIdx) {
                LLMAuditHelper.emitLLMCompletionAuditFromBackendIfNeeded(this.auditTrailService, enrichedLLMRef, llmMeshClient.getConnection(), (LLMClient.SingleCompletionQuery)in.get(recordIdx), responses.get(recordIdx));
            }
            promptResponse.stats.hasNoValidation = this.promptDef.resultValidation.hasNoSetRules();
            promptResponse.stats.testedRecords = promptResponse.responses.size();
            promptResponse.stats.validRecords = promptResponse.responses.stream().filter(pr -> !pr.error && pr.validationStatus == PromptResponse.ResponseValidationStatus.VALID).count();
            promptResponse.stats.invalidRecords = promptResponse.responses.stream().filter(pr -> !pr.error && pr.validationStatus == PromptResponse.ResponseValidationStatus.INVALID).count();
            promptResponse.stats.failedRecords = promptResponse.responses.stream().filter(pr -> pr.error).count();
            // Average the known per-record costs and scale to 1000 records; no costs -> 0.
            List<Double> validCosts = promptResponse.responses.stream().map(pr -> pr.estimatedCost).filter(Objects::nonNull).collect(Collectors.toList());
            double costPer1KRecords = validCosts.isEmpty() ? 0.0 : validCosts.stream().mapToDouble(Double::doubleValue).sum() / (double)validCosts.size() * 1000.0;
            // Keep the NaN guard in case an individual cost value is NaN.
            promptResponse.stats.estimatedCostPer1KRecords = Double.isNaN(costPer1KRecords) ? 0.0 : costPer1KRecords;
            return promptResponse;
        }
    }

    /**
     * Answers the last user message of a CHAT-mode prompt and returns the resulting response.
     *
     * <p>The user message is added to the conversation if absent. The full branch up to that
     * message is sent to the LLM. The assistant reply is appended to {@code promptDef.chatMessages}.
     * Emulated (non-streamed) completion is used when streaming is disabled, not supported by the
     * client, required by the guardrails settings, or the LLM is of type RETRIEVAL_AUGMENTED.
     * Otherwise {@code streamCompletion} is called.
     *
     * <p>NOTE(review): {@code streamCompletion} is a failed-decompilation stub in this source, so
     * the streaming path throws {@link IllegalStateException} here. {@code resp} is not used in
     * this body. Guardrail settings are resolved from the plain {@code llmRef} here but from the
     * enriched ref in {@link #respond()}; confirm whether that difference is intended.
     *
     * @param resp    HTTP response (unused in this body)
     * @param emitter SSE emitter receiving the streamed / emulated output
     * @return a response holding the updated conversation and the assistant reply
     * @throws Exception propagated from the LLM client or the emitter
     */
    public PromptResponse streamChatResponse(HttpServletResponse resp, MiniSSEEmitter emitter) throws Exception {
        assert (this.promptDef.promptMode == PromptStudio.PromptMode.CHAT);
        this.promptDef.lastUserMessage.runBy = this.authCtx.getIdentifier();
        if (!this.promptDef.chatMessages.containsKey(this.promptDef.lastUserMessage.id)) {
            this.promptStudiosCrudService.addNewChatMessage(this.promptDef.chatMessages, this.promptDef.lastUserMessage);
        }
        LLMClient.SingleCompletionQuery query = this.getStreamableCompletionRequestFromPrompt();
        LLMStructuredRef llmRef = LLMStructuredRef.decodeId(this.promptStudioPrompt.llmId);
        EnrichedLLMStructuredRef enrichedLLMRef = this.llmRefEnricherService.getEnrichedLLMRef(this.promptStudioPrompt.llmId, this.authCtx, this.promptStudio.projectKey);
        GuardrailsPipelineSettings connectionGuardrailsPipelineSettings = GuardrailsPipelineUtils.getConnectionAndLLMLevelSettings(this.authCtx, this.promptStudio.projectKey, llmRef);
        GuardrailsPipelineSettings guardrailsPipelineSettings = GuardrailsPipelineUtils.mergeEnforcementSettings(connectionGuardrailsPipelineSettings, this.promptDef.guardrailsPipelineSettings);
        try (LLMClient llmClient = LLMClientFactory.get(this.authCtx, this.promptStudio.projectKey, llmRef);){
            LLMClient.SimpleCompletionResponseOrError scre = this.promptDef.streamingDisabled || !llmClient.supportsStream() || GuardrailsPipelineUtils.needsNonStreamedNonParallelProcessing(guardrailsPipelineSettings) || enrichedLLMRef.type == LLMStructuredRef.LLMType.RETRIEVAL_AUGMENTED ? this.emulateStreamCompletion(query, emitter, llmRef, guardrailsPipelineSettings) : this.streamCompletion(query, emitter, enrichedLLMRef, guardrailsPipelineSettings);
            LLMAuditHelper.emitLLMCompletionAuditFromBackendIfNeeded(this.auditTrailService, enrichedLLMRef, llmClient.getConnection(), query, scre);
            PromptStudio.ConversationMessage assistantMessage = this.buildAssistantMessage(scre, llmRef);
            return this.buildChatPromptResponse(assistantMessage, scre);
        }
    }

    /**
     * Wraps a chat turn into a {@link PromptResponse} with a single entry holding the conversation.
     * The entry's validation result is also copied back onto {@code assistantMessage}.
     */
    private PromptResponse buildChatPromptResponse(PromptStudio.ConversationMessage assistantMessage, LLMClient.SimpleCompletionResponseOrError scre) {
        PromptResponse promptResponse = new PromptResponse();
        promptResponse.promptId = this.promptStudioPrompt.id;
        promptResponse.runBy = this.authCtx.getIdentifier();
        promptResponse.runOn = System.currentTimeMillis();
        promptResponse.runId = this.promptRunId;
        PromptResponse.SingleInputPromptResponse sipr = new PromptResponse.SingleInputPromptResponse();
        sipr.chatMessages = this.promptDef.chatMessages;
        sipr.lastMessageId = assistantMessage.id;
        this.fillSingleInputPromptResponse(sipr, scre);
        assistantMessage.validationStatus = sipr.validationStatus;
        assistantMessage.validationMessage = sipr.validationMessage;
        promptResponse.responses.add(sipr);
        return promptResponse;
    }

    /**
     * Restamps {@code promptResponse} for this run and truncates its first entry's conversation
     * to the ancestors of {@code sourceUserMessageId}, excluding that message itself.
     *
     * <p>The kept ancestor messages are mutated in place: their {@code version} is reset to 0.
     * {@code lastMessageId} becomes the source message's parent (null if it has none).
     *
     * @throws IllegalArgumentException if {@code sourceUserMessageId} is not in the conversation
     * @throws IllegalStateException    if an ancestor referenced by {@code parentId} is missing
     */
    public void forkResponse(PromptResponse promptResponse, String sourceUserMessageId) {
        promptResponse.promptId = this.promptStudioPrompt.id;
        promptResponse.runBy = this.authCtx.getIdentifier();
        promptResponse.runOn = System.currentTimeMillis();
        promptResponse.runId = this.promptRunId;
        List<PromptResponse.SingleInputPromptResponse> responses = promptResponse.responses;
        if (responses.isEmpty()) {
            return;
        }
        PromptResponse.SingleInputPromptResponse sipr = responses.get(0);
        PromptStudio.ConversationMessage sourceUserMessage = sipr.chatMessages.get(sourceUserMessageId);
        if (sourceUserMessage == null) {
            throw new IllegalArgumentException("Unknown source user message: " + sourceUserMessageId);
        }
        HashMap<String, PromptStudio.ConversationMessage> newChatMessages = new HashMap<String, PromptStudio.ConversationMessage>();
        String parentId = sourceUserMessage.parentId;
        sipr.lastMessageId = parentId;
        while (parentId != null) {
            PromptStudio.ConversationMessage parentMessage = sipr.chatMessages.get(parentId);
            if (parentMessage == null) {
                throw new IllegalStateException("Conversation references missing message: " + parentId);
            }
            parentMessage.version = 0;
            newChatMessages.put(parentId, parentMessage);
            parentId = parentMessage.parentId;
        }
        sipr.chatMessages = newChatMessages;
    }

    /**
     * Builds the assistant reply for the last user message and appends it to the conversation.
     * A failed or empty LLM response is recorded on the message as an error.
     */
    private PromptStudio.ConversationMessage buildAssistantMessage(LLMClient.SimpleCompletionResponseOrError scre, LLMStructuredRef llmRef) {
        PromptStudio.ConversationMessage assistantMessage = new PromptStudio.ConversationMessage();
        assistantMessage.parentId = this.promptDef.lastUserMessage.id;
        assistantMessage.message = new LLMClient.ChatMessage("assistant", scre.text);
        if (!scre.ok) {
            assistantMessage.error = true;
            assistantMessage.llmError = scre.errorMessage;
        } else if (StringUtils.isEmpty((CharSequence)scre.text)) {
            assistantMessage.error = true;
            assistantMessage.llmError = "LLM response is empty.";
        }
        assistantMessage.completionSettings = this.promptStudioPrompt.llmSettings;
        assistantMessage.llmStructuredRef = llmRef;
        assistantMessage.systemMessage = this.promptDef.chatSystemMessage;
        assistantMessage.fullTrace = scre.trace;
        if (scre.sources != null) {
            assistantMessage.sources = scre.sources;
        }
        this.promptStudiosCrudService.addNewChatMessage(this.promptDef.chatMessages, assistantMessage);
        return assistantMessage;
    }

    /**
     * Builds the completion query for the current chat branch: walks {@code parentId} links from
     * the last user message back to the root, keeping messages in chronological order. A non-blank
     * system message is prepended. Ancestors with a null {@code message} are skipped.
     *
     * @throws IllegalStateException if a referenced parent message is missing from the conversation
     */
    private LLMClient.SingleCompletionQuery getStreamableCompletionRequestFromPrompt() {
        LLMClient.SingleCompletionQuery query = new LLMClient.SingleCompletionQuery();
        query.messages = new ArrayList<LLMClient.ChatMessage>();
        query.messages.add(this.promptDef.lastUserMessage.message);
        String parentId = this.promptDef.lastUserMessage.parentId;
        while (parentId != null) {
            PromptStudio.ConversationMessage parentMessage = this.promptDef.chatMessages.get(parentId);
            if (parentMessage == null) {
                throw new IllegalStateException("Conversation references missing message: " + parentId);
            }
            if (parentMessage.message != null) {
                query.messages.add(0, parentMessage.message);
            }
            parentId = parentMessage.parentId;
        }
        if (this.promptDef.chatSystemMessage != null && !this.promptDef.chatSystemMessage.isBlank()) {
            LLMClient.ChatMessage systemMessage = new LLMClient.ChatMessage();
            systemMessage.role = "system";
            systemMessage.setTextOnly(this.promptDef.chatSystemMessage);
            query.messages.add(0, systemMessage);
        }
        return query;
    }

    /**
     * Streamed chat completion.
     *
     * <p>WARNING: CFR failed to decompile this method (trace below). In this source it is a stub
     * that always throws {@link IllegalStateException}. The original implementation must be
     * recovered before this class can be rebuilt from this file.
     */
    private LLMClient.SimpleCompletionResponseOrError streamCompletion(LLMClient.SingleCompletionQuery query, MiniSSEEmitter emitter, EnrichedLLMStructuredRef enrichedLLMRef, GuardrailsPipelineSettings guardrailsPipelineSettings) throws Exception {
        /*
         * This method has failed to decompile.  When submitting a bug report, please provide this stack trace, and (if you hold appropriate legal rights) the relevant class file.
         * 
         * org.benf.cfr.reader.util.ConfusedCFRException: Tried to end blocks [1[TRYBLOCK]], but top level block is 5[TRYBLOCK]
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.processEndingBlocks(Op04StructuredStatement.java:435)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.buildNestedBlocks(Op04StructuredStatement.java:484)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op03SimpleStatement.createInitialStructuredBlock(Op03SimpleStatement.java:736)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisInner(CodeAnalyser.java:850)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisOrWrapFail(CodeAnalyser.java:278)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysis(CodeAnalyser.java:201)
         *     at org.benf.cfr.reader.entities.attributes.AttributeCode.analyse(AttributeCode.java:94)
         *     at org.benf.cfr.reader.entities.Method.analyse(Method.java:531)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1055)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseTop(ClassFile.java:942)
         *     at org.benf.cfr.reader.Driver.doJarVersionTypes(Driver.java:257)
         *     at org.benf.cfr.reader.Driver.doJar(Driver.java:139)
         *     at org.benf.cfr.reader.CfrDriverImpl.analyse(CfrDriverImpl.java:76)
         *     at org.benf.cfr.reader.Main.main(Main.java:54)
         */
        throw new IllegalStateException("Decompilation failed");
    }

    /**
     * Runs a single non-streamed completion and reports it to the emitter as one block.
     *
     * <p>A "no-streaming" event is emitted first. An interrupt callback closes the mesh client.
     * If the emitter was interrupted, either after completion or while an exception was being
     * thrown, an "interrupted" error response is returned instead of the result / exception.
     *
     * @return the single completion result, or an error response on interruption
     * @throws Exception propagated from the LLM client when the emitter was not interrupted
     */
    private LLMClient.SimpleCompletionResponseOrError emulateStreamCompletion(LLMClient.SingleCompletionQuery query, MiniSSEEmitter emitter, LLMStructuredRef llmRef, GuardrailsPipelineSettings guardrailsPipelineSettings) throws Exception {
        logger.info((Object)("Streaming is not supported. " + JSON.json((Object)llmRef)));
        try (LLMMeshClient llmMeshClient = LLMMeshClientFactory.get(this.authCtx, this.promptStudio.projectKey, llmRef, guardrailsPipelineSettings, null, 0);){
            ArrayList<LLMClient.SingleCompletionQuery> singleCompletionQueries = new ArrayList<LLMClient.SingleCompletionQuery>();
            singleCompletionQueries.add(query);
            emitter.setInterruptCallback((ExceptionUtils.ThrowingRunnable<Exception>)((ExceptionUtils.ThrowingRunnable)() -> {
                logger.info((Object)"interrupting, closing llmClient");
                llmMeshClient.close();
            }));
            JSONObject noStreamJson = new JSONObject();
            if (this.promptDef.streamingDisabled) {
                noStreamJson.put("text", (Object)"Streaming is disabled. Response time may be longer.");
            } else {
                noStreamJson.put("text", (Object)"Selected model does not support streaming. Response time may be longer.");
            }
            emitter.sendEventWithData("no-streaming", noStreamJson.toString(), false);
            List<LLMClient.SimpleCompletionResponseOrError> responses = llmMeshClient.completeQueries(singleCompletionQueries, this.promptStudioPrompt.llmSettings.toFullSettings());
            ComputeResourceUsage cru = llmMeshClient.getTotalCRU(ComputeResourceUsage.LLMUsageType.COMPLETION);
            if (cru != null) {
                this.cruReportingService.reportComplete(cru);
            }
            if (emitter.isInterrupted()) {
                return LLMClient.SimpleCompletionResponseOrError.fromError(new Exception("Response generation was interrupted."));
            }
            return responses.get(0);
        }
        catch (Exception e) {
            if (!emitter.isInterrupted()) throw e;
            return LLMClient.SimpleCompletionResponseOrError.fromError(new Exception("Response generation was interrupted."));
        }
    }

    /**
     * Copies one completion result into {@code sipr}.
     *
     * <p>Failed or empty responses are marked as errors and are not validated. Otherwise the text,
     * token counts (null -> 0), cost (null -> 0.0), sources and additional information are copied,
     * and the result is validated against the prompt's rules.
     */
    private void fillSingleInputPromptResponse(PromptResponse.SingleInputPromptResponse sipr, LLMClient.SimpleCompletionResponseOrError resp) {
        sipr.fullTrace = resp.trace;
        if (!resp.ok) {
            sipr.error = true;
            sipr.llmError = resp.errorMessage;
            return;
        }
        if (StringUtils.isEmpty((CharSequence)resp.text)) {
            sipr.error = true;
            sipr.llmError = "LLM response is empty.";
            return;
        }
        sipr.rawLLMOutput = resp.text;
        sipr.promptTokens = resp.promptTokens == null ? 0 : resp.promptTokens;
        sipr.completionTokens = resp.completionTokens == null ? 0 : resp.completionTokens;
        sipr.totalTokens = resp.totalTokens == null ? 0 : resp.totalTokens;
        sipr.tokenCountsAreEstimated = resp.tokenCountsAreEstimated == null ? false : resp.tokenCountsAreEstimated;
        sipr.estimatedCost = resp.estimatedCost == null ? 0.0 : resp.estimatedCost;
        sipr.sources = resp.sources;
        sipr.additionalInformation = resp.additionalInformation;
        sipr.error = false;
        this.validateResponse(sipr);
    }

    /** Logs the prompt's validation rules and applies them to {@code sipr}. */
    private void validateResponse(PromptResponse.SingleInputPromptResponse sipr) {
        logger.info((Object)("Validating prompt result with rules: " + JSON.json((Object)this.promptDef.resultValidation)));
        sipr.validate(this.promptDef.resultValidation);
    }

    /**
     * Builds a preview of the main inputs for each dataset record, without calling the LLM.
     * Returns an empty preview unless the prompt reads from a DATASET and its mode can use one.
     */
    public PromptResponsePreview getPromptDatasetPreview() throws Exception {
        PromptResponsePreview promptResponsePreview = new PromptResponsePreview();
        if (this.promptDef.promptTemplateQueriesSource == PromptStudio.PromptTemplateQueriesSource.DATASET && this.promptDef.promptMode.canUseDataset) {
            EnrichedLLMStructuredRef enrichedLLMRef = this.llmRefEnricherService.getEnrichedLLMRef(this.promptStudioPrompt.llmId, this.authCtx, this.promptStudio.getProjectKey());
            MemTable recordsFromPromptTemplate = this.getRecordsForPromptTemplate();
            promptResponsePreview.promptId = this.promptStudioPrompt.id;
            promptResponsePreview.mainPromptTemplateInputs = this.getMainPromptTemplateInputsFromPromptDef(enrichedLLMRef.promptDriven);
            for (int recordIdx = 0; recordIdx < recordsFromPromptTemplate.nrows(); ++recordIdx) {
                PromptResponsePreview.SingleInputPromptResponsePreview siprp = new PromptResponsePreview.SingleInputPromptResponsePreview();
                siprp.mainInputs = this.getMainInputsFromRecords(recordsFromPromptTemplate, recordIdx, enrichedLLMRef.promptDriven, this.promptDef.singleInputColumn);
                promptResponsePreview.responses.add(siprp);
            }
        }
        return promptResponsePreview;
    }

    /**
     * Describes the input columns shown with results. Prompt-driven LLMs get one entry per
     * template input (name resolved via {@code getInputName}). Otherwise a single "Single input"
     * entry is returned.
     */
    private List<PromptStudio.PromptTemplateInput> getMainPromptTemplateInputsFromPromptDef(boolean isPromptDriven) {
        ArrayList<PromptStudio.PromptTemplateInput> mainPromptTemplateInputs = new ArrayList<PromptStudio.PromptTemplateInput>();
        if (isPromptDriven) {
            for (PromptStudio.PromptTemplateInput input : this.promptDef.getInputs()) {
                PromptStudio.PromptTemplateInput newInput = new PromptStudio.PromptTemplateInput();
                newInput.name = this.promptDef.getInputName(input);
                newInput.type = input.type;
                mainPromptTemplateInputs.add(newInput);
            }
        } else {
            PromptStudio.PromptTemplateInput input = new PromptStudio.PromptTemplateInput();
            input.name = "Single input";
            mainPromptTemplateInputs.add(input);
        }
        return mainPromptTemplateInputs;
    }

    /**
     * Extracts the main input values of record {@code recordIndex}, in the same order as
     * {@link #getMainPromptTemplateInputsFromPromptDef}. Inputs with a null name yield null.
     */
    private List<String> getMainInputsFromRecords(MemTable recordsFromPromptTemplate, int recordIndex, boolean isPromptDriven, String singleInputColumn) {
        ArrayList<String> mainInputs = new ArrayList<String>();
        if (isPromptDriven) {
            for (PromptStudio.PromptTemplateInput input : this.promptDef.getInputs()) {
                String inputName = this.promptDef.getInputName(input);
                String mainInput = null;
                if (inputName != null) {
                    MemColumn mc = recordsFromPromptTemplate.getColumn(inputName);
                    mainInput = recordsFromPromptTemplate.rows.get(recordIndex).get(mc);
                }
                mainInputs.add(mainInput);
            }
        } else {
            MemColumn mc = recordsFromPromptTemplate.getColumn(singleInputColumn);
            mainInputs.add(recordsFromPromptTemplate.rows.get(recordIndex).get(mc));
        }
        return mainInputs;
    }

    /**
     * Returns the records a template prompt is evaluated on. Sources are tried in this order:
     * forced records if set, otherwise the dataset sample (DATASET), otherwise a table built from
     * the inline queries (INLINE).
     *
     * <p>For INLINE, prompt-driven LLMs get one column per template input plus any keys present
     * in the query data. Other LLMs get the single synthetic inline-input column.
     *
     * @throws IllegalStateException if the queries source is neither DATASET nor INLINE
     */
    public MemTable getRecordsForPromptTemplate() throws Exception {
        if (this.forcedRecordsForPromptTemplate != null) {
            return this.forcedRecordsForPromptTemplate;
        }
        switch (this.promptDef.promptTemplateQueriesSource) {
            case DATASET: {
                return this.datasetToMemTable(this.promptStudioPrompt.dataset);
            }
            case INLINE: {
                // Only the inline layout depends on whether the LLM is prompt-driven.
                EnrichedLLMStructuredRef llmRef = this.llmRefEnricherService.getEnrichedLLMRef(this.promptStudioPrompt.llmId, this.authCtx, this.promptStudio.getProjectKey());
                MemTable mt = new MemTable();
                if (llmRef.promptDriven) {
                    for (PromptStudio.PromptTemplateInput pti : this.promptDef.getInputs()) {
                        mt.column(pti.name);
                    }
                    for (PromptStudio.InlinePromptTemplateQuery iptq : this.promptStudioPrompt.inlinePromptTemplateQueries) {
                        Row r = mt.row();
                        for (Map.Entry<String, String> e : iptq.data.entrySet()) {
                            r.put((Column)mt.column(e.getKey()), e.getValue());
                        }
                        mt.appendRow(r);
                    }
                } else {
                    mt.column(DKU_SINGLE_INLINE_INPUT);
                    for (PromptStudio.InlinePromptTemplateQuery iptq : this.promptStudioPrompt.inlinePromptTemplateQueries) {
                        Row r = mt.row();
                        r.put((Column)mt.column(DKU_SINGLE_INLINE_INPUT), iptq.singleInputData);
                        mt.appendRow(r);
                    }
                }
                return mt;
            }
        }
        throw new IllegalStateException("Unsupported prompt template queries source: " + this.promptDef.promptTemplateQueriesSource);
    }

    /**
     * Loads the first {@code nbRows} rows (HEAD_SEQUENTIAL sampling) of the referenced dataset
     * into a new {@link MemTable}. The dataset definition is read inside a read transaction.
     */
    private MemTable datasetToMemTable(String datasetRef) throws Exception {
        MemTable mt = new MemTable();
        SerializedDataset sd = null;
        try (Transaction t = this.transactionService.retrieveOrBeginRead();){
            sd = (SerializedDataset)this.datasetsDAO.getMandatory(AnyLoc.resolveSmart(this.promptStudio.getProjectKey(), datasetRef));
        }
        Dataset dataset = Dataset.fromSerialized(sd);
        SingleThreadPusherToMemTable stmt = new SingleThreadPusherToMemTable(this.authCtx, dataset, mt);
        DatasetSelectionToMemTable dsmt = new DatasetSelectionToMemTable();
        dsmt.samplingMethod = SamplingParam.SamplingMethod.HEAD_SEQUENTIAL;
        dsmt.maxRecords = this.promptStudioPrompt.nbRows;
        stmt.setDatasetSelection(dsmt);
        stmt.push();
        return mt;
    }

    /**
     * Decompiler artifact: judging by its name, the body of an interrupt-callback lambda inside
     * {@code streamCompletion}. It is not referenced anywhere in this source.
     */
    private static /* synthetic */ void lambda$streamCompletion$5(LLMClient llmClient) throws Exception {
        logger.info((Object)"interrupting, closing llmClient");
        llmClient.close();
    }
}

