/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.recipes.nlp.summarization;

import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.exec.AbstractInitializedRunner;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.ProcessorOutputToSIP;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.RowFactory;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMAuditHelper;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.PromptDef;
import com.dataiku.dip.llm.online.CompletionRecipeLLMMeshClient;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.prompts.PromptExpander;
import com.dataiku.dip.llm.promptstudio.PromptStudio;
import com.dataiku.dip.recipes.nlp.common.NLPLLMRecipeRunnerBase;
import com.dataiku.dip.recipes.nlp.common.NLPRecipeParallelRunInputFeedThread;
import com.dataiku.dip.recipes.nlp.summarization.SummarizationRecipePayloadParams;
import com.dataiku.dip.recipes.nlp.summarization.SummarizationRecipeSchemaComputer;
import com.dataiku.dip.server.services.AuditPrivilegedClient;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.StringTransmogrifier;
import com.dataiku.dip.warnings.WarningsContext;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.lang3.StringUtils;

/**
 * Recipe runner that summarizes the text of one input column through an LLM and writes, for
 * each response, either the summary or the LLM error message to the first "main" output.
 *
 * <p>{@link #run()} consumes responses from {@code plcStream} on the calling thread, while an
 * {@link InputFeedThread} (started by {@code run()}) is handed the same stream together with
 * the output column and row factories.
 */
public class SummarizationRecipeRunner extends NLPLLMRecipeRunnerBase {
    /** Recipe parameters parsed from the JSON payload by {@link #setPayload(String)}. */
    private SummarizationRecipePayloadParams desc;
    static DKULogger logger = DKULogger.getLogger("dku.recipes.nlp.summarization");

    public SummarizationRecipeRunner(JobActivity activity) {
        super(activity);
    }

    /**
     * Parses the recipe payload and registers the LLM id it designates.
     *
     * @param payload JSON representation of a {@link SummarizationRecipePayloadParams}
     */
    @Override
    public void setPayload(String payload) {
        this.desc = (SummarizationRecipePayloadParams)JSON.parse(payload, SummarizationRecipePayloadParams.class);
        this.setLlmId(this.desc.llmId);
    }

    /**
     * Runs the recipe: starts the input feed thread, then emits one output row per response
     * fetched from the completion stream until the stream reports no more responses.
     *
     * @throws IllegalArgumentException if the LLM id or the input column is blank
     * @throws IOException if the input feed thread terminated with an exception
     */
    @Override
    public void run() throws Exception {
        if (StringUtils.isBlank((CharSequence)this.desc.llmId)) {
            throw new IllegalArgumentException("No LLM was specified");
        }
        if (StringUtils.isBlank((CharSequence)this.desc.inputColumn)) {
            throw new IllegalArgumentException("No input column was specified");
        }
        AbstractInitializedRunner.Output output = (AbstractInitializedRunner.Output)((List)this.outputs.get("main")).get(0);
        StringTransmogrifier transmogrifier = this.getOutputColTransmogrifier();
        Columns columns = new Columns();
        columns.summarizedText = output.cf.column(transmogrifier.transmogrify(SummarizationRecipeSchemaComputer.SummarizationRecipeColumn.SUMMARIZED_TEXT.name));
        columns.errorMessage = output.cf.column(transmogrifier.transmogrify(SummarizationRecipeSchemaComputer.SummarizationRecipeColumn.LLM_ERROR_MSG.name));
        try (AuditPrivilegedClient auditClient = new AuditPrivilegedClient()) {
            // NOTE(review): enrichedLLMRef is logged here before it is reassigned from the mesh
            // client a few lines below; confirm it is already populated at this point.
            logger.info((Object)("Summarization recipe runner started, llmId=" + this.desc.llmId + " llmRef=" + JSON.log((Object)this.enrichedLLMRef)));
            ProcessorOutputToSIP processorOutput = new ProcessorOutputToSIP(output.out);
            try (CompletionRecipeLLMMeshClient meshClient = this.buildCompletionRecipeClient(null)) {
                // Both fields must be set before the feed thread is built: its constructor
                // reads plcStream and enrichedLLMRef from this runner.
                this.enrichedLLMRef = meshClient.getEnrichedRef();
                this.plcStream = meshClient.completeQueriesAsyncStream(this.buildCompletionSettings());
                InputFeedThread ift = new InputFeedThread((ColumnFactory)output.cf, (RowFactory)output.rf);
                ift.start();
                while (true) {
                    logger.info((Object)"Fetching next response from PLCS");
                    Optional o = this.plcStream.fetchNextResponse();
                    logger.info((Object)"Fetched response from PLCS");
                    if (!o.isPresent()) {
                        break;
                    }
                    CompletionRecipeLLMMeshClient.SimpleCompletionResponseOrErrorWithContext response = (CompletionRecipeLLMMeshClient.SimpleCompletionResponseOrErrorWithContext)o.get();
                    Row outputRow = output.rf.row();
                    // Copy every context entry into the output row, using the entry key as
                    // the output column name.
                    for (Map.Entry e : ((Map)response.context).entrySet()) {
                        outputRow.put((Column)output.cf.column((String)e.getKey()), (String)e.getValue());
                    }
                    LLMClient.SimpleCompletionResponseOrError scr = response.scr;
                    if (scr.ok) {
                        outputRow.put(columns.summarizedText, scr.text);
                    } else {
                        // A failed query still produces a row; the error goes to its own
                        // column and is also recorded as an activity warning.
                        outputRow.put(columns.errorMessage, scr.errorMessage);
                        this.activity.warnContext.addWarning(WarningsContext.WarningType.LLM_QUERY_ERROR, scr.errorMessage, logger);
                    }
                    processorOutput.emitRow(outputRow);
                    LLMAuditHelper.emitLLMCompletionAuditFromJobIfNeeded(this.authCtx, auditClient, meshClient.getEnrichedRef(), meshClient.getConnection(), response.completionQuery, scr);
                }
                logger.info((Object)"Terminated");
                processorOutput.lastRowEmitted();
                ift.join();
                if (ift.getException() != null) {
                    throw new IOException("Input feeding failed", ift.getException());
                }
                this.handleCRU(meshClient);
            }
        }
    }

    /**
     * Builds the summarization prompt for a recipe.
     *
     * <p>The instruction prefix always starts with the assistant role sentence and ends with the
     * {@code CANNOT_SUMMARIZE} fallback instruction. In between:
     * <ul>
     *   <li>a maximum-length sentence is added when {@code desc.controlTargetLength} is set and
     *       the LLM type is not {@code HUGGINGFACE_TRANSFORMER_LOCAL};</li>
     *   <li>a language sentence is added when {@code enrichedLLMRef.canGenerateCrossLanguageOutput}
     *       is set, naming {@code desc.outputLanguage} if it is not blank and otherwise asking
     *       for the language of the original text.</li>
     * </ul>
     * A single prompt input named "text to summarize" is bound to {@code desc.inputColumn}.
     *
     * @param desc recipe parameters
     * @param enrichedLLMRef description of the target LLM
     * @return the prompt definition
     */
    public static PromptDef getPrompt(SummarizationRecipePayloadParams desc, EnrichedLLMStructuredRef enrichedLLMRef) {
        PromptDef prompt = PromptDef.forRecipe();
        prompt.structuredPromptPrefix = "You are a helpful assistant that summarizes the following text.";
        if (desc.controlTargetLength && LLMStructuredRef.LLMType.HUGGINGFACE_TRANSFORMER_LOCAL != enrichedLLMRef.type) {
            // Locale.ROOT keeps the unit name stable whatever the JVM default locale is; the
            // no-arg toLowerCase() would, for instance, turn 'I' into a dotless 'ı' under a
            // Turkish default locale and corrupt the prompt text.
            String unit = desc.targetLengthUnit.name().toLowerCase(java.util.Locale.ROOT);
            prompt.structuredPromptPrefix = prompt.structuredPromptPrefix + " Your summary must be at most " + desc.targetLength + "-" + unit + " long.";
        }
        if (enrichedLLMRef.canGenerateCrossLanguageOutput) {
            if (StringUtils.isNotBlank((CharSequence)desc.outputLanguage)) {
                prompt.structuredPromptPrefix = prompt.structuredPromptPrefix + " Write your summary in the following language: " + desc.outputLanguage + "\n";
            } else {
                prompt.structuredPromptPrefix = prompt.structuredPromptPrefix + " Write your summary in the same language as the original text\n";
            }
        }
        prompt.structuredPromptPrefix = prompt.structuredPromptPrefix + "\nIf you cannot summarize, you must answer \"CANNOT_SUMMARIZE\".\n\n";
        PromptStudio.PromptTemplateInput pti = new PromptStudio.PromptTemplateInput();
        pti.name = "text to summarize";
        pti.datasetColumnName = desc.inputColumn;
        prompt.getInputs().add(pti);
        return prompt;
    }

    /**
     * Builds the completion settings passed to the mesh client.
     *
     * <p>The summarization-specific token settings are copied from the recipe parameters only
     * when the LLM is not prompt-driven; otherwise the settings are returned as constructed.
     *
     * @return a new settings object
     */
    public LLMClient.CompletionSettings buildCompletionSettings() {
        LLMClient.CompletionSettings settings = new LLMClient.CompletionSettings();
        if (!this.enrichedLLMRef.promptDriven) {
            settings.summarizationMinTokens = this.desc.huggingFaceMinTokens;
            settings.summarizationMaxTokens = this.desc.huggingFaceMaxTokens;
            settings.summarizationSpecialTokensSafetyFactor = this.desc.specialTokensSafetyFactor;
            settings.summarizationNumOverlapTokens = this.desc.numOverlapTokens;
            settings.summarizationMaxNumSplitLevels = this.desc.maxNumSplitLevels;
        }
        return settings;
    }

    /** Holder for the two output columns written by {@link #run()}. */
    private static class Columns {
        Column summarizedText;
        Column errorMessage;

        private Columns() {
        }
    }

    /**
     * Feed thread that turns each input row into a completion query.
     *
     * <p>It reads {@code plcStream}, {@code enrichedLLMRef} and {@code desc} from the enclosing
     * runner, so those must be set before an instance is created.
     */
    private class InputFeedThread
    extends NLPRecipeParallelRunInputFeedThread {
        /** Set only for prompt-driven LLMs; {@code null} otherwise. */
        private PromptExpander promptExpander;

        InputFeedThread(ColumnFactory cf, RowFactory rf) throws IOException {
            super(SummarizationRecipeRunner.this.authCtx, SummarizationRecipeRunner.this.recipe, SummarizationRecipeRunner.this.activity, SummarizationRecipeRunner.this.plcStream, cf, rf);
            if (SummarizationRecipeRunner.this.enrichedLLMRef.promptDriven) {
                this.promptExpander = new PromptExpander(SummarizationRecipeRunner.this.enrichedLLMRef, SummarizationRecipeRunner.getPrompt(SummarizationRecipeRunner.this.desc, SummarizationRecipeRunner.this.enrichedLLMRef), null);
            }
        }

        /**
         * Builds the query for one input row: the expanded prompt when a prompt expander is
         * present, otherwise a single "user" message holding the raw input column value.
         *
         * @param row the input row
         * @return the completion query for that row
         */
        @Override
        public LLMClient.SingleCompletionQuery buildCompletionQuery(Row row) {
            if (this.promptExpander != null) {
                return this.promptExpander.expand(this.cf, row);
            }
            String text = row.get(this.cf.column(SummarizationRecipeRunner.this.desc.inputColumn));
            LLMClient.SingleCompletionQuery query = new LLMClient.SingleCompletionQuery();
            query.messages.add(new LLMClient.ChatMessage("user", text));
            return query;
        }
    }
}

