/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.recipes.nlp.llm_evaluation;

import com.dataiku.dip.analysis.model.core.LLMCustomEvaluationMetric;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.datasets.Type;
import com.dataiku.dip.recipes.RecipeSchemaComputer;
import com.dataiku.dip.recipes.nlp.llm_evaluation.LLMEvaluationRecipePayloadParams;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.common.collect.Lists;
import java.util.Collections;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Computes the output schemas of the LLM Evaluation recipe.
 *
 * <p>The recipe payload (see {@link #setPayload(String)}) carries the list of built-in metrics, the
 * custom metrics and the input format. Those settings drive which columns are appended to:
 * <ul>
 *   <li>the {@code "main"} output, which starts from a copy of the input dataset schema;</li>
 *   <li>the {@code "metrics"} output, which starts from a single {@code date} column.</li>
 * </ul>
 * The {@code "evaluationStore"} output role has no dataset schema.
 */
public class LLMEvaluationRecipeSchemaComputer
extends RecipeSchemaComputer
implements RecipeSchemaComputer.RecipeSchemaComputerWithPayload {
    /** Columns emitted by the {@code "bertScore"} metric. */
    public static final List<SchemaColumn> BERT_SCORE_METRIC_SCHEMA_COLUMNS = List.of(new SchemaColumn("bertScorePrecision", Type.DOUBLE), new SchemaColumn("bertScoreRecall", Type.DOUBLE), new SchemaColumn("bertScoreF1", Type.DOUBLE));
    /** Columns emitted by the {@code "rouge"} metric (ROUGE-1, ROUGE-2 and ROUGE-L precision/recall/F1). */
    public static final List<SchemaColumn> ROUGE_METRIC_SCHEMA_COLUMNS = List.of(new SchemaColumn("rouge1Precision", Type.DOUBLE), new SchemaColumn("rouge1Recall", Type.DOUBLE), new SchemaColumn("rouge1F1", Type.DOUBLE), new SchemaColumn("rouge2Precision", Type.DOUBLE), new SchemaColumn("rouge2Recall", Type.DOUBLE), new SchemaColumn("rouge2F1", Type.DOUBLE), new SchemaColumn("rougeLPrecision", Type.DOUBLE), new SchemaColumn("rougeLRecall", Type.DOUBLE), new SchemaColumn("rougeLF1", Type.DOUBLE));
    /** Token-count columns appended to the main output for Prompt Recipe inputs (INT-typed). */
    public static final List<SchemaColumn> TOKEN_COUNT_METRIC_OUTPUT_SCHEMA_COLUMNS = List.of(new SchemaColumn("inputTokensPerRow", Type.INT), new SchemaColumn("outputTokensPerRow", Type.INT));
    /**
     * Token-count columns appended to the metrics output for Prompt Recipe inputs. They are
     * DOUBLE-typed, presumably because they hold aggregated values (TODO confirm).
     */
    public static final List<SchemaColumn> TOKEN_COUNT_METRIC_METRIC_SCHEMA_COLUMNS = List.of(new SchemaColumn("inputTokensPerRow", Type.DOUBLE), new SchemaColumn("outputTokensPerRow", Type.DOUBLE));
    public static final String PARSED_OUTPUT_NAME = "dkuParsedOutput";
    public static final String PARSED_CONTEXT_NAME = "dkuParsedContexts";
    public static final String RECONSTRUCTED_INPUT_NAME = "dkuReconstructedInput";
    @Autowired
    private TransactionService transactionService;
    @Autowired
    protected DatasetsDAO datasetsDAO;
    /** Parsed recipe payload; set by {@link #setPayload(String)} before schemas are computed. */
    private LLMEvaluationRecipePayloadParams params;
    private static final DKULogger logger = DKULogger.getLogger("recipes.nlp.llm_evaluation.schema");

    public LLMEvaluationRecipeSchemaComputer(AuthCtx authCtx, JobActivity activity) {
        super(authCtx, activity);
    }

    /**
     * Parses the JSON recipe payload into {@link LLMEvaluationRecipePayloadParams}.
     *
     * @param payload JSON payload of the recipe
     */
    @Override
    public void setPayload(String payload) {
        this.params = (LLMEvaluationRecipePayloadParams) JSON.parse(payload, LLMEvaluationRecipePayloadParams.class);
    }

    /**
     * Returns the schemas for the given output role, after checking the recipe against its descriptor.
     *
     * @param role one of {@code "main"}, {@code "metrics"} or {@code "evaluationStore"}
     * @return a single-element mutable list with the computed schema, or an empty list when the role
     *     has no output configured ({@code "main"}/{@code "metrics"}) or carries no dataset schema
     *     ({@code "evaluationStore"})
     * @throws IllegalArgumentException if {@code role} is not one of the roles above
     */
    @Override
    public List<Schema> getSchemasForOutputRole_NT(String role) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            this.recipesValidationService.checkComplianceWithRecipeDesc(this.authCtx, this.recipe);
        }
        switch (role) {
            case "main":
                // The schema is only computed when an output is actually bound to the role.
                return this.recipe.getOutputsForRole("main").isEmpty()
                        ? Collections.emptyList()
                        : Lists.newArrayList(this.getMainSchema_NT());
            case "metrics":
                return this.recipe.getOutputsForRole("metrics").isEmpty()
                        ? Collections.emptyList()
                        : Lists.newArrayList(this.getMetricsSchema());
            case "evaluationStore":
                return Collections.emptyList();
            default:
                throw new IllegalArgumentException(String.format("Output role %s is not compatible with the LLM Evaluation recipe.", role));
        }
    }

    /**
     * Builds the main output schema.
     *
     * <p>Column order: the input dataset columns, then the built-in metric columns, then the custom
     * metric columns, then the input-format-specific columns (reconstructed input, parsed output,
     * parsed contexts and token counts for Prompt Recipe; parsed contexts for Dataiku Answers).
     *
     * @throws IllegalArgumentException if the recipe has no main input
     */
    private Schema getMainSchema_NT() throws Exception {
        if (this.recipe.getInputsUnsafe().get("main").items.isEmpty()) {
            throw new IllegalArgumentException("The LLM Evaluation recipe requires an input dataset.");
        }
        Schema schema = this.getInputSchemaCopy();
        this.addBuiltInMetricColumns(schema);
        this.addCustomMetricColumns(schema);
        if (this.isPromptRecipeInput()) {
            schema.addColumn(RECONSTRUCTED_INPUT_NAME, Type.STRING);
            schema.addColumn(PARSED_OUTPUT_NAME, Type.STRING);
            schema.addColumn(PARSED_CONTEXT_NAME, Type.STRING);
            schema.addColumns(TOKEN_COUNT_METRIC_OUTPUT_SCHEMA_COLUMNS);
        } else if (this.params.inputFormat == LLMEvaluationRecipePayloadParams.LLMEvalInputFormat.DATAIKU_ANSWERS) {
            schema.addColumn(PARSED_CONTEXT_NAME, Type.STRING);
        }
        logger.infoV("Output dataset schema : %s", new Object[]{schema});
        return schema;
    }

    /**
     * Returns a copy of the main input dataset's schema, so that callers can append columns without
     * touching the dataset's own schema object.
     */
    private Schema getInputSchemaCopy() throws Exception {
        Dataset inputDataset;
        try (Transaction t = this.transactionService.beginRead()) {
            AnyLoc inputDatasetLoc = this.recipe.getSingleInput("main").getLoc(this.recipe.getProjectKey());
            inputDataset = this.datasetAccessService.getMandatoryUnsafe(inputDatasetLoc);
        }
        return inputDataset.getSchema().getCopy();
    }

    /**
     * Builds the metrics output schema.
     *
     * <p>Column order: {@code date}, then the built-in metric columns, then the token-count columns
     * (Prompt Recipe inputs only), then the custom metric columns.
     */
    private Schema getMetricsSchema() {
        Schema schema = new Schema();
        schema.addColumn(new SchemaColumn("date", Type.DATE));
        this.addBuiltInMetricColumns(schema);
        if (this.isPromptRecipeInput()) {
            schema.addColumns(TOKEN_COUNT_METRIC_METRIC_SCHEMA_COLUMNS);
        }
        this.addCustomMetricColumns(schema);
        logger.infoV("Metrics schema: %s", new Object[]{schema});
        return schema;
    }

    /**
     * Appends the columns of every built-in metric listed in the payload, in payload order.
     *
     * <p>Shared by the main and metrics schemas so that both stay in sync:
     * <ul>
     *   <li>{@code "bertScore"} and {@code "rouge"} expand to several columns;</li>
     *   <li>multimodal metrics are skipped (and logged) unless the input is a Prompt Recipe output;</li>
     *   <li>every other metric becomes a single DOUBLE column named after the metric.</li>
     * </ul>
     */
    private void addBuiltInMetricColumns(Schema schema) {
        for (String metric : this.params.metrics) {
            if ("bertScore".equals(metric)) {
                schema.addColumns(BERT_SCORE_METRIC_SCHEMA_COLUMNS);
            } else if ("rouge".equals(metric)) {
                schema.addColumns(ROUGE_METRIC_SCHEMA_COLUMNS);
            } else if (isMultimodalMetric(metric) && !this.isPromptRecipeInput()) {
                logger.infoV("The input format of the evaluation dataset is not Prompt Recipe (is %s): not considering multimodal metric %s", new Object[]{this.params.inputFormat, metric});
            } else {
                schema.addColumn(new SchemaColumn(metric, Type.DOUBLE));
            }
        }
    }

    /** Appends one DOUBLE column per custom metric, named after the metric, in payload order. */
    private void addCustomMetricColumns(Schema schema) {
        for (LLMCustomEvaluationMetric customMetric : this.params.customMetrics) {
            schema.addColumn(new SchemaColumn(customMetric.name, Type.DOUBLE));
        }
    }

    /** Whether the payload's input format is the Prompt Recipe format. */
    private boolean isPromptRecipeInput() {
        return this.params.inputFormat == LLMEvaluationRecipePayloadParams.LLMEvalInputFormat.PROMPT_RECIPE;
    }

    /** Whether {@code metric} names one of the multimodal metrics, which only apply to Prompt Recipe inputs. */
    private static boolean isMultimodalMetric(String metric) {
        return "multimodalFaithfulness".equals(metric) || "multimodalRelevancy".equals(metric);
    }
}

