# encoding: utf-8
"""
Execute a LLM Evaluation recipe. Must be called in a Flow environment
"""

import logging
import sys

import dataiku
import pandas as pd
import json

from dataiku import Dataset
from dataiku.base.folder_context import build_folder_context
from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.core import debugging, schema_handling
from dataiku.doctor.evaluation.base import sample_and_store_dataframe
from dataiku.doctor import utils, step_constants
from dataiku.doctor.diagnostics import default_diagnostics
from dataiku.doctor.utils.listener import ProgressListener, DiagOnlyContext
from dataiku.llm.evaluation.exceptions import LLMEvalException
from dataiku.llm.evaluation.llm_metrics_input import LLMMetricInput
from dataiku.llm.evaluation.llm_eval_recipe_desc import LLMEvalRecipeDesc
from dataiku.llm.evaluation.utils import bert_score_utils, bleu_utils, rouge_utils, prompt_recipe_utils, token_count_utils
from dataiku.llm.evaluation.utils.bert_score_utils import compute_bert_score, has_bert_score
from dataiku.llm.evaluation.utils.bleu_utils import compute_bleu, has_bleu
from dataiku.llm.evaluation.utils.custom_metrics_utils import RecipeParamsForCustomMetrics, compute_custom_metric
from dataiku.llm.evaluation.utils import dataiku_answers_utils
from dataiku.llm.evaluation.utils.failure_utils import raise_or_continue
from dataiku.llm.evaluation.utils.rouge_utils import has_rouge, compute_rouge
from dataiku.llm.evaluation.utils.token_count_utils import compute_token_count, can_token_count

from importlib.metadata import PackageNotFoundError, version

# Select the ragas helper module matching the installed ragas release.
# `version()` raises PackageNotFoundError (it does not return a falsy value) when the
# distribution is not installed, so that case must be caught to report the intended error.
try:
    ragas_version = version("ragas")
except PackageNotFoundError as e:
    raise Exception("ragas package is missing.") from e
if not ragas_version:
    raise Exception("ragas package is missing.")

# Compare the exact "major.minor" components: a plain prefix test on "0.1" / "0.2" would
# also accept releases such as 0.10.x or 0.20.x, which are not supported.
_ragas_major_minor = ragas_version.split(".")[:2]
if _ragas_major_minor == ["0", "1"]:
    from dataiku.llm.evaluation.utils.ragas.ragas_utils_0_1_10 import check_use_ragas_metrics, has_ragas_metrics, RagasMetricsComputer, \
        has_context_based_metrics, create_empty_ragas_metrics
elif _ragas_major_minor == ["0", "2"]:
    from dataiku.llm.evaluation.utils.ragas.ragas_utils_0_2_12 import check_use_ragas_metrics, has_ragas_metrics, RagasMetricsComputer, \
        has_context_based_metrics, create_empty_ragas_metrics
else:
    raise Exception(f"Version of ragas {ragas_version} is not supported (only 0.1.X and 0.2.X)")

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Names of the extra columns this recipe adds to its row-by-row output dataframe.
# When parsing a "raw response" JSON from a prompt recipe, the extracted output is stored in this column.
PARSED_OUTPUT_NAME = 'dkuParsedOutput'  # Keep in sync with LLMEvaluationRecipeSchemaComputer.java
# When parsing a "raw response" JSON from a prompt recipe or an answer plugin, the extracted context is stored in this column (as an array of strings).
PARSED_CONTEXT_NAME = 'dkuParsedContexts'   # Keep in sync with LLMEvaluationRecipeSchemaComputer.java
# "Full" input from the prompt recipe, with system prompt, examples, etc.
RECONSTRUCTED_INPUT_NAME = 'dkuReconstructedInput'   # Keep in sync with LLMEvaluationRecipeSchemaComputer.java


class LLMEvaluationRecipe(object):
    """
    Runner of an LLM Evaluation recipe.

    It reads the input dataset, derives the "interpreted" input / output / context columns,
    computes the metrics enabled in the recipe description and writes the results to whichever
    outputs are configured: the model evaluation folder, the output dataset and the metrics dataset.

    Use :meth:`build` to create an instance for a real run and :meth:`build_for_test` to create
    one that is only meant to be used with :meth:`test_custom_metric`.
    """
    # `run_folder` and `model_evaluation_folder` hold folder contexts (see `build`); they may be None
    # (both are None for instances created with `build_for_test`).
    input_dataset_smartname: str
    output_dataset_smartname: str
    metrics_dataset_smartname: str

    recipe_desc: "LLMEvalRecipeDesc"

    input_column_name: str
    output_column_name: str
    ground_truth_column_name: str
    context_column_name: str

    has_ground_truth: bool
    has_context: bool
    has_model_evaluation_store: bool
    has_output_dataset: bool
    has_metrics_dataset: bool

    # Mutable run state. Only declared here: the objects are created per instance in __init__,
    # because class-level mutable defaults would be shared (and mutated in place) by all instances.
    output_df: pd.DataFrame
    metrics_df: pd.DataFrame
    perf: dict

    ragas_max_workers: int

    def __init__(self, recipe_desc, run_folder, model_evaluation_folder, input_dataset_smartname, output_dataset_smartname, metrics_dataset_smartname, ragas_max_workers):
        """
        :param recipe_desc: the recipe description; its column names may be updated in place (see below)
        :param run_folder: folder context of the run folder (None when built for test)
        :param model_evaluation_folder: folder context of the model evaluation, or None if there is none
        :param input_dataset_smartname: name of the dataset to evaluate
        :param output_dataset_smartname: name of the row-by-row output dataset (None or empty if absent)
        :param metrics_dataset_smartname: name of the metrics dataset (None or empty if absent)
        :param ragas_max_workers: converted with int() and handed to the RAGAS metrics computer
        :raises LLMEvalException: if a column name set in recipe_desc conflicts with the one imposed by its input format
        """
        self.run_folder = run_folder
        self.model_evaluation_folder = model_evaluation_folder

        self.input_dataset_smartname = input_dataset_smartname
        self.output_dataset_smartname = output_dataset_smartname
        self.metrics_dataset_smartname = metrics_dataset_smartname

        self.recipe_desc = recipe_desc

        # Define default input, output and context columns of recipe_desc
        # If you create the recipe with the public API, they are not automatically defined so we infer them.
        expected_columns = self._expected_columns(recipe_desc.input_format)
        if expected_columns is not None:
            for column_name, expected_value in expected_columns.items():
                actual_value = getattr(recipe_desc, column_name, None)
                if actual_value is not None and actual_value != expected_value:
                    raise LLMEvalException(
                        f"{column_name} should be set to '{expected_value}' instead of '{actual_value}'."
                    )
                setattr(recipe_desc, column_name, expected_value)

        self.input_column_name = recipe_desc.input_column_name
        self.output_column_name = recipe_desc.output_column_name
        self.ground_truth_column_name = recipe_desc.ground_truth_column_name
        self.context_column_name = recipe_desc.context_column_name
        self.custom_metrics = recipe_desc.custom_metrics

        self.has_ground_truth = recipe_desc.ground_truth_column_name is not None
        self.has_context = recipe_desc.context_column_name is not None
        self.has_model_evaluation_store = self.model_evaluation_folder is not None
        # A dataset is considered present when its smart name is neither None nor empty
        self.has_output_dataset = bool(self.output_dataset_smartname)
        self.has_metrics_dataset = bool(self.metrics_dataset_smartname)
        self.ragas_max_workers = int(ragas_max_workers)

        # Per-instance result holders
        self.output_df = pd.DataFrame()
        self.metrics_df = pd.DataFrame.from_dict({'date': [utils.get_datetime_now_utc()]})
        self.perf = {'metrics': {}}

    @staticmethod
    def _expected_columns(input_format):
        """
        Return the recipe_desc column names imposed by `input_format`, as a dict
        {recipe_desc attribute name: column name}, or None if the format does not impose any.
        """
        if input_format == "PROMPT_RECIPE":
            return {
                "input_column_name": prompt_recipe_utils.PROMPT_RECIPE_RAW_QUERY_NAME,
                "output_column_name": prompt_recipe_utils.PROMPT_RECIPE_RAW_RESPONSE_NAME,
                "context_column_name": prompt_recipe_utils.PROMPT_RECIPE_RAW_RESPONSE_NAME,
            }
        if input_format == "DATAIKU_ANSWERS":
            return {
                "input_column_name": dataiku_answers_utils.DATAIKU_ANSWERS_QUESTION_NAME,
                "output_column_name": dataiku_answers_utils.DATAIKU_ANSWERS_ANSWER_NAME,
                "context_column_name": dataiku_answers_utils.DATAIKU_ANSWERS_SOURCES_NAME,
            }
        return None

    @classmethod
    def build(cls, run_folder_path, model_evaluation_folder_path, input_dataset_smartname, output_dataset_smartname, metrics_dataset_smartname, ragas_max_workers):
        """
        Create an LLMEvaluationRecipe from folder paths.
        The recipe description is read from "desc.json" in the run folder. No model evaluation
        folder context is built when `model_evaluation_folder_path` is empty or None.
        """
        run_folder = build_folder_context(run_folder_path)
        model_evaluation_folder = build_folder_context(model_evaluation_folder_path) if model_evaluation_folder_path else None
        recipe_desc = LLMEvalRecipeDesc(run_folder.read_json("desc.json"))
        return cls(recipe_desc, run_folder, model_evaluation_folder, input_dataset_smartname, output_dataset_smartname, metrics_dataset_smartname, ragas_max_workers)

    @classmethod
    def build_for_test(cls, recipe_desc, input_dataset_smartname):
        """
        Create an LLMEvaluationRecipe that can be used to test custom metrics, through test_custom_metric.
        It has no run folder, no model evaluation folder and no output or metrics dataset.
        :param recipe_desc: the recipe description
        :param input_dataset_smartname: name of the dataset to read the sample from
        :return: the LLMEvaluationRecipe
        """
        return cls(recipe_desc, None, None, input_dataset_smartname, None, None, 0)

    def run(self):
        """
        Execute the recipe: read and check the input dataset, build the interpreted columns,
        compute the metrics, then write each configured output (model evaluation folder,
        output dataset, metrics dataset).
        :raises LLMEvalException: if the sanity checks on the input dataframe fail
        """
        default_diagnostics.register_evaluation_callbacks()
        listener = ProgressListener(context=DiagOnlyContext(self.run_folder))

        input_df = self._get_input_df()

        self._run_sanity_checks(input_df)

        # The row-by-row output starts as a copy of the input; metric columns are appended to it later
        if self.has_output_dataset or self.has_model_evaluation_store:
            self.output_df = input_df.copy()

        with listener.push_step(step_constants.ProcessingStep.STEP_EVAL_PROCESSING):
            interpreted_columns = self._craft_interpreted_columns(input_df)

        with listener.push_step(step_constants.ProcessingStep.STEP_EVAL_PROCESSING):
            self._compute_metrics(input_df, interpreted_columns)

        if self.has_model_evaluation_store:
            logger.info("Writing performance in ME folder")
            self.model_evaluation_folder.write_json("llm_perf.json", self.perf)
            # Record the number of evaluated rows in the existing evaluation file
            evaluation_file = "_evaluation.json"
            evaluation = self.model_evaluation_folder.read_json(evaluation_file)
            evaluation['nbEvaluationRows'] = input_df.shape[0]
            self.model_evaluation_folder.write_json(evaluation_file, evaluation)
            output_df_schema = {"columns": schema_handling.get_schema_from_df(self.output_df)}
            sample_and_store_dataframe(self.model_evaluation_folder,
                                       self.output_df,
                                       output_df_schema,
                                       filename='sample_with_metrics.csv.gz',
                                       schema_filename='sample_with_metrics_schema.json',
                                       limit_sampling=True)

        if self.has_output_dataset:
            output_dataset = Dataset(self.output_dataset_smartname)
            logger.info("Writing output dataset")
            output_dataset.write_dataframe(self.output_df)

        if self.has_metrics_dataset:
            metrics_dataset = Dataset(self.metrics_dataset_smartname)
            logger.info("Writing metrics data")
            metrics_dataset.write_dataframe(self.metrics_df)

    def test_custom_metric(self, custom_metric):
        """
        Run a given custom_metric on the first 5 lines of the input dataset.
        Useful to assert the correctness of a given custom_metric.
        Do not write to MES or any output dataset.
        Returns the result of the custom_metric on the 5 first lines.
        :param custom_metric: the custom metric definition to test
        :return: the result returned by compute_custom_metric for the sample
        :raises LLMEvalException: if the sanity checks on the input dataframe fail
        """
        input_df = self._get_input_df()
        self._run_sanity_checks(input_df)
        sample_df = input_df[0:5]
        recipe_params_for_custom_metrics = RecipeParamsForCustomMetrics(self.recipe_desc, self.recipe_desc.completion_settings)
        interpreted_columns = self._craft_interpreted_columns(sample_df)
        logger.info("Custom metric: %s will be tested", custom_metric.name)
        return compute_custom_metric(custom_metric, sample_df, interpreted_columns, False, recipe_params_for_custom_metrics)

    def _run_sanity_checks(self, input_df):
        """
        Check that every configured column exists in the input dataframe and that it has at least one row.
        :raises LLMEvalException: if a configured column is missing or if the dataframe is empty
        """
        columns_to_check = (
            ("Input", self.input_column_name),
            ("Output", self.output_column_name),
            ("Ground truth", self.ground_truth_column_name),
            ("Context", self.context_column_name),
        )
        for label, column_name in columns_to_check:
            if column_name is not None and column_name not in input_df.columns:
                raise LLMEvalException(
                    f"{label} column {column_name} is not a column of {self.input_dataset_smartname}.")
        if len(input_df) == 0:
            raise LLMEvalException(
                "The evaluation dataset is empty. Please check your recipe configuration (dataset, sampling parameters, etc.)")
        logger.info("Input dataframe of shape %s", input_df.shape)

    def _craft_interpreted_columns(self, input_df):
        """
        Build the LLMMetricInput used by the metrics from the input dataframe.
        For the PROMPT_RECIPE and DATAIKU_ANSWERS input formats, the relevant values are extracted
        from the raw columns and also stored in dedicated columns of the output dataframe.
        :return: the LLMMetricInput
        """
        interpreted_columns = LLMMetricInput.from_df(input_df, self.recipe_desc)
        if (self.recipe_desc.input_format == 'PROMPT_RECIPE'
                and self.input_column_name == prompt_recipe_utils.PROMPT_RECIPE_RAW_QUERY_NAME
                and self.output_column_name == prompt_recipe_utils.PROMPT_RECIPE_RAW_RESPONSE_NAME
                and self.context_column_name == prompt_recipe_utils.PROMPT_RECIPE_RAW_RESPONSE_NAME):
            interpreted_columns.input = prompt_recipe_utils.try_get_reconstructed_prompt_recipe_input(input_df)
            self.output_df[RECONSTRUCTED_INPUT_NAME] = interpreted_columns.input

            if prompt_recipe_utils.has_raw_response(input_df):
                interpreted_columns.output = prompt_recipe_utils.try_get_parsed_prompt_recipe_output(input_df)
                interpreted_columns.context = prompt_recipe_utils.try_get_parsed_prompt_recipe_context(input_df, has_context_based_metrics(self.recipe_desc.metrics))
            else:
                # No raw response to parse: fall back to empty series
                interpreted_columns.output = pd.Series(dtype=object)
                interpreted_columns.context = pd.Series(dtype=object)

            # The parsed contexts are stored JSON-encoded in the output dataframe
            self.output_df[PARSED_CONTEXT_NAME] = interpreted_columns.context.apply(json.dumps)
            self.output_df[PARSED_OUTPUT_NAME] = interpreted_columns.output

        elif (self.recipe_desc.input_format == 'DATAIKU_ANSWERS'
              and self.input_column_name == dataiku_answers_utils.DATAIKU_ANSWERS_QUESTION_NAME
              and self.output_column_name == dataiku_answers_utils.DATAIKU_ANSWERS_ANSWER_NAME
              and self.context_column_name == dataiku_answers_utils.DATAIKU_ANSWERS_SOURCES_NAME):
            interpreted_columns.context = dataiku_answers_utils.try_get_parsed_dataiku_answer_context(input_df, has_context_based_metrics(self.recipe_desc.metrics))
            self.output_df[PARSED_CONTEXT_NAME] = interpreted_columns.context

        return interpreted_columns

    def _compute_metrics(self, input_df, interpreted_columns):
        """
        Compute every metric family enabled in the recipe description (RAGAS, BERT Score, BLEU,
        ROUGE, token count), then the custom metrics, and record the results through
        _update_outputs / _update_outputs_custom_metric.

        When a built-in family raises an LLMEvalException, raise_or_continue is called with the
        recipe's fail_on_errors setting; if it returns, the "empty" metrics of that family are recorded instead.
        """
        if has_ragas_metrics(self.recipe_desc.metrics):
            logger.info("Some metrics using RAGAS will be computed")
            try:
                check_use_ragas_metrics(self.recipe_desc.metrics, self.has_ground_truth, self.has_context, self.recipe_desc.completion_llm_id,
                                        self.recipe_desc.embedding_llm_id, self.recipe_desc.input_format == "PROMPT_RECIPE")
                ragas_metrics_computer = RagasMetricsComputer(
                    self.recipe_desc.completion_llm_id,
                    self.recipe_desc.completion_settings,
                    self.recipe_desc.embedding_llm_id,
                    self.ragas_max_workers,
                    False,
                    can_compute_multimodal_metrics=self.recipe_desc.input_format == "PROMPT_RECIPE")
                ragas_perf, ragas_metric_df = ragas_metrics_computer.compute(interpreted_columns, self.recipe_desc.metrics)
            except LLMEvalException as e:
                raise_or_continue(e, "some RAGAS", self.recipe_desc.fail_on_errors, self.recipe_desc.input_format)
                ragas_perf, ragas_metric_df = create_empty_ragas_metrics(interpreted_columns, self.recipe_desc.metrics)
            self._update_outputs(ragas_perf, ragas_metric_df)

        if has_bert_score(self.recipe_desc.metrics):
            logger.info("BERT Score metrics will be computed")
            try:
                bert_perf, bert_metrics_df = compute_bert_score(interpreted_columns, self.recipe_desc.bert_score_model_type)
            except LLMEvalException as e:
                raise_or_continue(e, "BERT Score", self.recipe_desc.fail_on_errors, self.recipe_desc.input_format)
                bert_perf, bert_metrics_df = bert_score_utils.create_empty_bert_score(interpreted_columns)
            self._update_outputs(bert_perf, bert_metrics_df)

        if has_bleu(self.recipe_desc.metrics):
            logger.info("BLEU metrics will be computed")
            try:
                bleu_perf, bleu_metrics_df = compute_bleu(interpreted_columns, self.recipe_desc.bleu_tokenizer)
            except LLMEvalException as e:
                raise_or_continue(e, "BLEU", self.recipe_desc.fail_on_errors, self.recipe_desc.input_format)
                bleu_perf, bleu_metrics_df = bleu_utils.create_empty_bleu(interpreted_columns)
            self._update_outputs(bleu_perf, bleu_metrics_df)

        if has_rouge(self.recipe_desc.metrics):
            logger.info("ROUGE metrics will be computed")
            try:
                rouge_perf, rouge_metrics_df = compute_rouge(interpreted_columns)
            except LLMEvalException as e:
                raise_or_continue(e, "ROUGE", self.recipe_desc.fail_on_errors, self.recipe_desc.input_format)
                rouge_perf, rouge_metrics_df = rouge_utils.create_empty_rouge(interpreted_columns)
            self._update_outputs(rouge_perf, rouge_metrics_df)

        if can_token_count(self.recipe_desc):
            logger.info("Token count metrics will be computed")
            try:
                token_count, token_count_df = compute_token_count(input_df, self.recipe_desc.input_format)
            except LLMEvalException as e:
                raise_or_continue(e, "Token Count", self.recipe_desc.fail_on_errors, self.recipe_desc.input_format)
                token_count, token_count_df = token_count_utils.create_empty_token_count(interpreted_columns)
            self._update_outputs(token_count, token_count_df)
        else:
            logger.info("Token count metrics won't be computed")

        recipe_params_for_custom_metrics = RecipeParamsForCustomMetrics(self.recipe_desc, self.recipe_desc.completion_settings)
        for custom_metric in self.custom_metrics:
            logger.info("Custom metric: %s will be computed", custom_metric.name)
            # Custom metrics errors are handled internally, because we want to return errors in the custom_metric_result
            result = compute_custom_metric(custom_metric, input_df, interpreted_columns, self.recipe_desc.fail_on_errors, recipe_params_for_custom_metrics)
            self._update_outputs_custom_metric(custom_metric, result)

    def _get_input_df(self):
        """
        Read the input dataset as a dataframe, using the sampling of the recipe description.
        The ground truth and context columns are parsed with json.loads when the dataset schema
        declares them as 'array'.
        :return: the input dataframe
        """
        input_dataset = dataiku.Dataset(self.input_dataset_smartname)
        input_schema = input_dataset.read_schema()
        input_df = input_dataset.get_dataframe(sampling=self.recipe_desc.sampling)
        json_columns = {self.ground_truth_column_name, self.context_column_name}
        for column_schema in input_schema:
            column = column_schema['name']
            if column_schema['type'] == 'array' and column in json_columns:
                input_df[column] = input_df[column].apply(json.loads)

        return input_df

    def _update_outputs(self, perf_metrics: dict, row_by_row_metrics_df: pd.DataFrame):
        """
        This updates the output df, the metrics df and the performance metrics file written in the model evaluation folder
        :param perf_metrics: A dictionary containing the global metrics, used for the metrics df and the performance of the model evaluation
        :param row_by_row_metrics_df: The dataframe with the metrics computed for each row
        """
        if self.has_output_dataset or self.has_model_evaluation_store:
            self.output_df = pd.concat([self.output_df, row_by_row_metrics_df], axis=1)

        if self.has_model_evaluation_store:
            self.perf['metrics'].update(perf_metrics)

        if self.has_metrics_dataset:
            # The global metrics become extra columns of the single-row metrics dataframe
            self.metrics_df = pd.concat([self.metrics_df, pd.DataFrame.from_dict([perf_metrics])], axis=1)

    def _update_outputs_custom_metric(self, custom_metric, custom_metrics_result: dict):
        """
        This updates the output df, the metrics df and the performance metrics file written in the model evaluation folder
        :param custom_metric: a custom metric definition
        :param custom_metrics_result: a dict looking like a RowByRowCustomMetricSuccess (didSucceed, value, rowByRowValues)
        """
        if self.has_output_dataset or self.has_model_evaluation_store:
            self.output_df[custom_metric.name] = custom_metrics_result.get('rowByRowValues', None)

        if self.has_model_evaluation_store:
            self.perf['metrics'].setdefault('customMetricsResults', []).append(custom_metrics_result)

        if self.has_metrics_dataset:
            self.metrics_df[custom_metric.name] = [custom_metrics_result.get('value', None)] # TODO : what about append ?


if __name__ == "__main__":
    # Process setup: debugging handler, logging format and DSS environment
    debugging.install_handler()
    logging.basicConfig(
        level=logging.INFO,
        format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s',
    )
    read_dku_env_and_set()

    # Positional command-line arguments, read before entering the error monitoring wrapper
    cli_args = sys.argv
    run_folder_path = cli_args[1]
    model_evaluation_folder_path = cli_args[2]
    input_dataset_smartname = cli_args[3]
    output_dataset_smartname = cli_args[4]
    metrics_dataset_smartname = cli_args[5]
    ragas_max_workers = cli_args[6]

    with ErrorMonitoringWrapper():
        recipe = LLMEvaluationRecipe.build(
            run_folder_path,
            model_evaluation_folder_path,
            input_dataset_smartname,
            output_dataset_smartname,
            metrics_dataset_smartname,
            ragas_max_workers,
        )
        recipe.run()
