/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.recipes.nlp.classification.model_provided;

import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.exec.AbstractInitializedRunner;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.ProcessorOutputToSIP;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.RowFactory;
import com.dataiku.dip.llm.LLMAuditHelper;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.PromptDef;
import com.dataiku.dip.llm.online.CompletionRecipeLLMMeshClient;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.prompts.PromptExpander;
import com.dataiku.dip.llm.promptstudio.PromptStudio;
import com.dataiku.dip.recipes.nlp.classification.model_provided.NLPLLMModelProvidedClassificationRecipePayloadParams;
import com.dataiku.dip.recipes.nlp.classification.model_provided.NLPLLMModelProvidedClassificationRecipeSchemaComputer;
import com.dataiku.dip.recipes.nlp.common.NLPLLMRecipeRunnerBase;
import com.dataiku.dip.recipes.nlp.common.NLPRecipeParallelRunInputFeedThread;
import com.dataiku.dip.server.services.AuditPrivilegedClient;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ExceptionUtils;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.StringTransmogrifier;
import com.dataiku.dip.warnings.WarningsContext;
import com.google.common.collect.Lists;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;

public class NLPLLMModelProvidedClassificationRecipeRunner
extends NLPLLMRecipeRunnerBase {
    private NLPLLMModelProvidedClassificationRecipePayloadParams desc;
    private Column outputClazzCD;
    private Column outputScoreCD;
    private Column rawLLMOutputCD;
    private Column errorMessageCD;
    static DKULogger logger = DKULogger.getLogger((String)"dku.recipes.nlp.builtin_classes_classification");

    public NLPLLMModelProvidedClassificationRecipeRunner(JobActivity activity) {
        super(activity);
    }

    @Override
    public void setPayload(String payload) {
        this.desc = (NLPLLMModelProvidedClassificationRecipePayloadParams)JSON.parse((String)payload, NLPLLMModelProvidedClassificationRecipePayloadParams.class);
        this.setLlmId(this.desc.llmId);
    }

    @Override
    public void run() throws Exception {
        logger.info((Object)"Classification recipe API runner started");
        if (StringUtils.isBlank((CharSequence)this.desc.llmId)) {
            throw new IllegalArgumentException("No LLM was specified");
        }
        if (StringUtils.isBlank((CharSequence)this.desc.inputColumn)) {
            throw new IllegalArgumentException("No input column was specified");
        }
        AbstractInitializedRunner.Output output = (AbstractInitializedRunner.Output)((List)this.outputs.get("main")).get(0);
        StringTransmogrifier transmogrifier = this.getOutputColTransmogrifier();
        this.outputClazzCD = output.cf.column(transmogrifier.transmogrify(NLPLLMModelProvidedClassificationRecipeSchemaComputer.NLPLLMModelProvidedClassificationRecipeColumn.OUTPUT.name));
        this.outputScoreCD = output.cf.column(transmogrifier.transmogrify(NLPLLMModelProvidedClassificationRecipeSchemaComputer.NLPLLMModelProvidedClassificationRecipeColumn.OUTPUT_SCORE.name));
        this.rawLLMOutputCD = output.cf.column(transmogrifier.transmogrify(NLPLLMModelProvidedClassificationRecipeSchemaComputer.NLPLLMModelProvidedClassificationRecipeColumn.LLM_RAW_OUTPUT.name));
        this.errorMessageCD = output.cf.column(transmogrifier.transmogrify(NLPLLMModelProvidedClassificationRecipeSchemaComputer.NLPLLMModelProvidedClassificationRecipeColumn.LLM_ERROR_MSG.name));
        try (AuditPrivilegedClient auditClient = new AuditPrivilegedClient();){
            ProcessorOutputToSIP processorOutput = new ProcessorOutputToSIP(output.out);
            try (CompletionRecipeLLMMeshClient meshClient = this.buildCompletionRecipeClient(null);){
                this.enrichedLLMRef = meshClient.getEnrichedRef();
                this.plcStream = meshClient.completeQueriesAsyncStream(this.buildCompletionSettings());
                InputFeedThread ift = new InputFeedThread((ColumnFactory)output.cf, (RowFactory)output.rf);
                ift.start();
                while (true) {
                    LLMClient.SimpleCompletionResponseOrError scr;
                    Row outputRow;
                    Optional o;
                    block33: {
                        logger.info((Object)"Fetching next response from PLCS");
                        o = this.plcStream.fetchNextResponse();
                        logger.info((Object)"Fetched response from PLCS");
                        if (!o.isPresent()) break;
                        outputRow = output.rf.row();
                        for (Map.Entry e : ((Map)((CompletionRecipeLLMMeshClient.SimpleCompletionResponseOrErrorWithContext)o.get()).context).entrySet()) {
                            outputRow.put((Column)output.cf.column((String)e.getKey()), (String)e.getValue());
                        }
                        scr = ((CompletionRecipeLLMMeshClient.SimpleCompletionResponseOrErrorWithContext)o.get()).scr;
                        if (scr.ok) {
                            try {
                                JsonObject jo;
                                if (this.enrichedLLMRef.promptDriven) {
                                    if (LLMClient.CompletionSettings.ClassificationOutputMode.ALL.equals((Object)this.desc.outputMode)) {
                                        jo = (JsonObject)JSON.parse((String)scr.text, JsonObject.class);
                                        if (jo == null) {
                                            throw new IOException("Failed to parse JSON dict from LLM response");
                                        }
                                        LinkedHashMap<String, String> dict = new LinkedHashMap<String, String>();
                                        for (Map.Entry elt : jo.entrySet()) {
                                            if (!((JsonElement)elt.getValue()).isJsonPrimitive()) {
                                                throw new IOException("Unexpected LLM response shape");
                                            }
                                            dict.put((String)elt.getKey(), ((JsonElement)elt.getValue()).getAsString());
                                        }
                                        outputRow.put(this.outputClazzCD, JSON.json(dict));
                                    } else {
                                        outputRow.put(this.outputClazzCD, scr.text);
                                    }
                                    break block33;
                                }
                                if (this.enrichedLLMRef.type == LLMStructuredRef.LLMType.HUGGINGFACE_TRANSFORMER_LOCAL) {
                                    if (LLMClient.CompletionSettings.ClassificationOutputMode.ALL.equals((Object)this.desc.outputMode)) {
                                        logger.info((Object)("Raw LLM answer: " + scr.text));
                                        jo = (JsonObject)JSON.parse((String)scr.text, JsonObject.class);
                                        if (jo == null) {
                                            throw new IOException("Failed to parse JSON dict from LLM response");
                                        }
                                        Map<String, Double> probas = new LinkedHashMap();
                                        for (Map.Entry elt : jo.entrySet()) {
                                            if (!((JsonElement)elt.getValue()).isJsonPrimitive()) {
                                                throw new IOException("Unexpected LLM response shape");
                                            }
                                            double p = ((JsonElement)elt.getValue()).getAsDouble();
                                            probas.put((String)elt.getKey(), p);
                                        }
                                        probas = probas.entrySet().stream().sorted(Map.Entry.comparingByValue((v1, v2) -> v2.compareTo((Double)v1))).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (e1, e2) -> e1, LinkedHashMap::new));
                                        outputRow.put(this.outputClazzCD, JSON.json(probas));
                                    } else {
                                        jo = (JsonObject)JSON.parse((String)scr.text, JsonObject.class);
                                        if (jo == null) {
                                            throw new IOException("Failed to parse JSON dict from LLM response");
                                        }
                                        outputRow.put(this.outputClazzCD, jo.get("label").getAsString());
                                        outputRow.put(this.outputScoreCD, jo.get("score").getAsString());
                                    }
                                    break block33;
                                }
                                throw new IllegalArgumentException("Do not know how to handle output of LLM: " + this.enrichedLLMRef.id + " for task " + String.valueOf((Object)this.desc.task));
                            }
                            catch (Exception e) {
                                outputRow.put(this.rawLLMOutputCD, scr.text);
                                String errorMessage = "LLM response did not match expectation:" + ExceptionUtils.getMessageWithCauses((Throwable)e);
                                outputRow.put(this.errorMessageCD, errorMessage);
                                this.activity.warnContext.addWarning(WarningsContext.WarningType.LLM_QUERY_ERROR, errorMessage, logger);
                            }
                        } else {
                            outputRow.put(this.errorMessageCD, scr.errorMessage);
                            this.activity.warnContext.addWarning(WarningsContext.WarningType.LLM_QUERY_ERROR, scr.errorMessage, logger);
                        }
                    }
                    processorOutput.emitRow(outputRow);
                    LLMAuditHelper.emitLLMCompletionAuditFromJobIfNeeded(this.authCtx, auditClient, this.enrichedLLMRef, meshClient.getConnection(), ((CompletionRecipeLLMMeshClient.SimpleCompletionResponseOrErrorWithContext)o.get()).completionQuery, scr);
                }
                logger.info((Object)"Terminated");
                processorOutput.lastRowEmitted();
                ift.join();
                if (ift.getException() != null) {
                    throw new IOException("Input feeding failed", ift.getException());
                }
                this.handleCRU(meshClient);
            }
        }
    }

    public static PromptDef getSentimentAnalysisPrompt(NLPLLMModelProvidedClassificationRecipePayloadParams desc) {
        PromptDef prompt = PromptDef.forRecipe();
        PromptStudio.PromptTemplateInput pti = new PromptStudio.PromptTemplateInput();
        pti.name = "text to classify";
        pti.datasetColumnName = desc.inputColumn;
        prompt.getInputs().add(pti);
        String lang = desc.lang == null ? "en" : desc.lang;
        block4 : switch (lang.toLowerCase(Locale.ENGLISH)) {
            case "en": {
                switch (desc.outputMode) {
                    case ALL: {
                        prompt.structuredPromptPrefix = "You are a helpful assistant that analyzes sentiment in the following text. You must answer a JSON object, with 3 keys: positive, neutral, negative.The value for each key must be the intensity of the sentiment, among these five grades only: overwhelmingly, very much, quite, not much, not at all.";
                        prompt.structuredPromptExamples = Lists.newArrayList((Object[])new PromptStudio.StructuredPromptTemplateExample[]{PromptStudio.StructuredPromptTemplateExample.newSingleInput("It was really bad, I could not stand it", "{\"negative\": \"overwhelmingly\", \"positive\": \"not much\", \"neutral\": \"not at all\"}")});
                        break block4;
                    }
                    case FIRST: {
                        prompt.structuredPromptPrefix = "You are a helpful assistant that analyzes sentiment in the following text. You must answer either 'positive', 'negative' or 'neutral'. You must not answer anything else.\n";
                        prompt.structuredPromptExamples = Lists.newArrayList((Object[])new PromptStudio.StructuredPromptTemplateExample[]{PromptStudio.StructuredPromptTemplateExample.newSingleInput("It was really bad, I could not stand it", "negative"), PromptStudio.StructuredPromptTemplateExample.newSingleInput("I am in love with this work", "positive"), PromptStudio.StructuredPromptTemplateExample.newSingleInput("The Eiffel Tower is 320 meters high", "neutral")});
                        break block4;
                    }
                }
                throw new IllegalArgumentException("Sentiment analysis cannot be used with the output mode " + String.valueOf((Object)desc.outputMode));
            }
            case "fr": {
                switch (desc.outputMode) {
                    case ALL: {
                        prompt.structuredPromptPrefix = "Tu es un assistant qui analyse le sentiment du texte suivant. Tu dois r\u00e9pondre un dictionnaire JSON contenant 3 cl\u00e9s: positive, neutral, negative.La valeur de chaque cl\u00e9 doit \u00eatre l'intensit\u00e9 de ce sentiment, uniquement parmis l'\u00e9chelle suivante: extr\u00eamement, beaucoup, un peu, pas beaucoup, pas du tout";
                        prompt.structuredPromptExamples = Lists.newArrayList((Object[])new PromptStudio.StructuredPromptTemplateExample[]{PromptStudio.StructuredPromptTemplateExample.newSingleInput("C'\u00e9tait vraiment nul, je ne pouvais pas le supporter", "{\"negative\": \"extr\u00eamement\", \"positive\": \"pas beaucoup\", \"neutral\": \"pas du tout\"}")});
                        break block4;
                    }
                    case FIRST: {
                        prompt.structuredPromptPrefix = "Tu es un assistant qui analyse le sentiment du texte suivant. Tu dois r\u00e9pondre soit 'positive', 'neutral' ou 'negative'. Ne r\u00e9ponds rien d'autre.\n";
                        prompt.structuredPromptExamples = Lists.newArrayList((Object[])new PromptStudio.StructuredPromptTemplateExample[]{PromptStudio.StructuredPromptTemplateExample.newSingleInput("C'\u00e9tait vraiment nul, je ne pouvais pas le supporter", "negative"), PromptStudio.StructuredPromptTemplateExample.newSingleInput("J'adore ce travail", "positive"), PromptStudio.StructuredPromptTemplateExample.newSingleInput("La Tour Eiffel mesure 320 m\u00e8tres", "neutral")});
                        break block4;
                    }
                }
                throw new IllegalArgumentException("Sentiment analysis cannot be used with the output mode " + String.valueOf((Object)desc.outputMode));
            }
            default: {
                throw new IllegalArgumentException("Unhandled language: " + lang);
            }
        }
        return prompt;
    }

    public static PromptDef getEmotionAnalysisPrompt(NLPLLMModelProvidedClassificationRecipePayloadParams desc) {
        ArrayList emotions = Lists.newArrayList((Object[])new String[]{"disappointment", "sadness", "annoyance", "neutral", "disapproval", "realization", "nervousness", "approval", "joy", "anger", "embarrassment", "caring", "remorse", "disgust", "grief", "confusion", "relief", "desire", "admiration", "optimism", "fear", "love", "excitement", "curiosity", "amusement", "surprise", "gratitude", "pride"});
        PromptDef prompt = PromptDef.forRecipe();
        PromptStudio.PromptTemplateInput pti = new PromptStudio.PromptTemplateInput();
        pti.name = "text";
        pti.datasetColumnName = desc.inputColumn;
        prompt.getInputs().add(pti);
        String lang = desc.lang == null ? "en" : desc.lang;
        block4 : switch (lang.toLowerCase(Locale.ENGLISH)) {
            case "en": {
                switch (desc.outputMode) {
                    case ALL: {
                        prompt.structuredPromptPrefix = String.format("You are a helpful assistant that detects which emotions are expressed in the following text. The possible emotions are: %s. 'neutral' is the default if no other specific emotion can be observed. You must answer with a JSON object containing  one key for each of the previously listed emotions. The value for each key must be the intensity of the emotion/how well it describes the overall tone of the text, among these five grades only: overwhelmingly, very much, quite, not much, not at all. You must not answer with anything else.", String.join((CharSequence)", ", emotions));
                        prompt.structuredPromptExamples = Lists.newArrayList((Object[])new PromptStudio.StructuredPromptTemplateExample[]{PromptStudio.StructuredPromptTemplateExample.newSingleInput("I appreciate it, that's good to know. I hope I'll have to apply that knowledge one day", "{\"gratitude\": \"overwhelmingly\", \"desire\": \"overwhelmingly\", \"admiration\": \"not much\", \"optimism\": \"not much\", \"approval\": \"not at all\", \"caring\": \"not at all\",\"joy\": \"not at all\", \"love\": \"not at all\", \"excitement\": \"not at all\", \"pride\": \"not at all\", \"neutral\": \"not at all\", \"disapproval\": \"not at all\",\"realization\": \"not at all\", \"curiosity\": \"not at all\", \"disappointment\": \"not at all\", \"surprise\": \"not at all\", \"annoyance\": \"not at all\",\"remorse\": \"not at all\", \"confusion\": \"not at all\", \"relief\": \"not at all\", \"amusement\": \"not at all\", \"anger\": \"not at all\", \"sadness\": \"not at all\",\"fear\": \"not at all\", \"disgust\": \"not at all\", \"grief\": \"not at all\", \"nervousness\": \"not at all\", \"embarrassment\": \"not at all\"}"), PromptStudio.StructuredPromptTemplateExample.newSingleInput("I have a deep seeded hatred for close-minded people too.", "{\"anger\": \"very much\", \"annoyance\": \"quite\", \"disgust\": \"not much\", \"disapproval\": \"not much\", \"neutral\":  \"not at all\", \"approval\": \"not at all\",\"sadness\": \"not at all\", \"disappointment\": \"not at all\", \"caring\": \"not at all\", \"admiration\": \"not at all\", \"realization\": \"not at all\",\"curiosity\": \"not at all\", \"fear\": \"not at all\", \"confusion\": \"not at all\", \"optimism\": \"not at all\", \"love\": \"not at all\", \"gratitude\": \"not at all\",\"embarrassment\": \"not at all\", \"desire\": \"not at all\", \"remorse\": \"not at all\", \"excitement\": \"not at all\", \"joy\": \"not at all\", \"grief\": \"not at all\",\"amusement\": \"not at all\", \"surprise\": \"not at all\", \"pride\": \"not at all\", \"nervousness\": \"not at all\", \"relief\": \"not at all\"}")});
                        break block4;
                    }
                    case FIRST: {
                        prompt.structuredPromptPrefix = String.format("You are a helpful assistant that detects which emotions are expressed in the following text. The possible emotions are: %s. 'neutral' is the default if no other specific emotion can be observed. You must answer with the one emotion that is expressed the most. The emotion should properly describe the overall tone of the text. You must not answer with anything else.", String.join((CharSequence)", ", emotions));
                        prompt.structuredPromptExamples = Lists.newArrayList((Object[])new PromptStudio.StructuredPromptTemplateExample[]{PromptStudio.StructuredPromptTemplateExample.newSingleInput("I have a deep seeded hatred for close-minded people too.", "anger"), PromptStudio.StructuredPromptTemplateExample.newSingleInput("I appreciate it, that's good to know. I hope I'll have to apply that knowledge one day", "gratitude"), PromptStudio.StructuredPromptTemplateExample.newSingleInput("I am not having a great day.", "sadness")});
                        break block4;
                    }
                    case MOST_RELEVANT: {
                        prompt.structuredPromptPrefix = String.format("You are a helpful assistant that detects which emotions are expressed in the following text. The possible emotions are: %s. 'neutral' is the default if no other specific emotion can be observed. You must answer with the emotions (one or more) that are expressed the most, in decreasing intensity. You should separate each emotion by a comma and a blank space. These emotions should properly describe the overall tone of the text. You must not answer with anything else.", String.join((CharSequence)", ", emotions));
                        prompt.structuredPromptExamples = Lists.newArrayList((Object[])new PromptStudio.StructuredPromptTemplateExample[]{PromptStudio.StructuredPromptTemplateExample.newSingleInput("I have a deep seeded hatred for close-minded people too.", "anger"), PromptStudio.StructuredPromptTemplateExample.newSingleInput("I appreciate it, that's good to know. I hope I'll have to apply that knowledge one day", "gratitude, desire"), PromptStudio.StructuredPromptTemplateExample.newSingleInput("I am not having a great day.", "sadness, disappointment")});
                        break block4;
                    }
                }
                throw new IllegalArgumentException("Emotion analysis cannot be used with the output mode " + String.valueOf((Object)desc.outputMode));
            }
            case "fr": {
                break;
            }
            default: {
                throw new IllegalArgumentException("Unhandled language: " + lang);
            }
        }
        return prompt;
    }

    private LLMClient.CompletionSettings buildCompletionSettings() {
        LLMClient.CompletionSettings cs2 = new LLMClient.CompletionSettings();
        if (!this.enrichedLLMRef.promptDriven) {
            cs2.textClassificationOutputMode = this.desc.outputMode;
        }
        return cs2;
    }

    private class InputFeedThread
    extends NLPRecipeParallelRunInputFeedThread {
        private PromptExpander promptExpander;

        InputFeedThread(ColumnFactory cf, RowFactory rf) throws IOException {
            super(NLPLLMModelProvidedClassificationRecipeRunner.this.authCtx, NLPLLMModelProvidedClassificationRecipeRunner.this.recipe, NLPLLMModelProvidedClassificationRecipeRunner.this.activity, NLPLLMModelProvidedClassificationRecipeRunner.this.plcStream, cf, rf);
            switch (NLPLLMModelProvidedClassificationRecipeRunner.this.desc.task) {
                case SENTIMENT_ANALYSIS: {
                    if (((NLPLLMModelProvidedClassificationRecipeRunner)NLPLLMModelProvidedClassificationRecipeRunner.this).enrichedLLMRef.promptDriven) {
                        this.promptExpander = new PromptExpander(NLPLLMModelProvidedClassificationRecipeRunner.this.enrichedLLMRef, NLPLLMModelProvidedClassificationRecipeRunner.getSentimentAnalysisPrompt(NLPLLMModelProvidedClassificationRecipeRunner.this.desc), NLPLLMModelProvidedClassificationRecipeRunner.this.recipe.getProjectKey());
                        break;
                    }
                    if (((NLPLLMModelProvidedClassificationRecipeRunner)NLPLLMModelProvidedClassificationRecipeRunner.this).enrichedLLMRef.canDoNativeSentimentAnalysis) break;
                    logger.warn((Object)"This LLM is not designed to do sentiment analysis. Results may not match expectations");
                    break;
                }
                case EMOTION_ANALYSIS: {
                    if (((NLPLLMModelProvidedClassificationRecipeRunner)NLPLLMModelProvidedClassificationRecipeRunner.this).enrichedLLMRef.promptDriven) {
                        this.promptExpander = new PromptExpander(NLPLLMModelProvidedClassificationRecipeRunner.this.enrichedLLMRef, NLPLLMModelProvidedClassificationRecipeRunner.getEmotionAnalysisPrompt(NLPLLMModelProvidedClassificationRecipeRunner.this.desc), NLPLLMModelProvidedClassificationRecipeRunner.this.recipe.getProjectKey());
                        break;
                    }
                    if (((NLPLLMModelProvidedClassificationRecipeRunner)NLPLLMModelProvidedClassificationRecipeRunner.this).enrichedLLMRef.canDoNativeEmotionAnalysis) break;
                    logger.warn((Object)"This LLM is not designed to do emotion analysis. Results may not match expectations");
                    break;
                }
                case OTHER: {
                    if (!((NLPLLMModelProvidedClassificationRecipeRunner)NLPLLMModelProvidedClassificationRecipeRunner.this).enrichedLLMRef.promptDriven) break;
                    throw new IllegalArgumentException("The LLM that you selected requires specific instructions for each task. It can thus not be used for 'Other' classification. Please either choose a specialized LLM or use the Prompt recipe instead, providing your own instructions");
                }
            }
        }

        @Override
        public LLMClient.SingleCompletionQuery buildCompletionQuery(Row row) {
            if (this.promptExpander != null) {
                return this.promptExpander.expand(this.cf, row);
            }
            String v = row.get(this.cf.column(NLPLLMModelProvidedClassificationRecipeRunner.this.desc.inputColumn));
            LLMClient.SingleCompletionQuery cq = new LLMClient.SingleCompletionQuery();
            cq.messages.add(new LLMClient.ChatMessage("user", v));
            return cq;
        }
    }
}

