# encoding: utf-8
"""
Execute an LLM Evaluation recipe. Must be called in a Flow environment.
"""

import json
import logging
import sys
from typing import Any, Dict, Optional

import pandas as pd

from dataiku.base.folder_context import FolderContext, build_folder_context
from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.core import debugging
from dataiku.llm.evaluation.genai_custom_evaluation_metric  import GenAiCustomEvaluationMetric
from dataiku.llm.evaluation.exceptions import LLMEvalException
from dataiku.llm.evaluation.genai_evaluation_recipe import PARSED_CONTEXT_NAME, PARSED_OUTPUT_NAME, RECONSTRUCTED_INPUT_NAME, GenAIEvaluationRecipe
from dataiku.llm.evaluation.genai_metrics_input import GenAIMetricInput
from dataiku.llm.evaluation.genai_eval_recipe_desc import GenAIEvalRecipeDesc
from dataiku.llm.evaluation.utils import bert_score_utils, bleu_utils, custom_metrics_utils, dataiku_answers_utils, failure_utils, prompt_recipe_utils, rouge_utils, token_count_utils

from importlib.metadata import PackageNotFoundError, version

# Pick the RAGAS helper module that matches the installed ragas release line.
try:
    ragas_version = version("ragas")
except PackageNotFoundError as err:
    # ``version`` raises (it never returns a falsy value) when the distribution is not installed,
    # so translate that into the explicit error message this module intends to report.
    raise ImportError("ragas package is missing.") from err
if not ragas_version:
    # Defensive: kept from the original code in case the metadata lookup yields an empty value.
    raise ImportError("ragas package is missing.")

# Compare the (major, minor) components instead of a raw string prefix, so that a version
# such as "0.10.0" is not mistaken for a 0.1.X release.
_ragas_release_line = tuple(ragas_version.split(".")[:2])
if _ragas_release_line == ("0", "1"):
    from dataiku.llm.evaluation.utils.ragas.ragas_utils_0_1_10 import check_use_ragas_metrics, has_ragas_llm_metrics, RagasMetricsComputer, \
        has_context_based_metrics, create_empty_ragas_llm_metrics
elif _ragas_release_line == ("0", "2"):
    from dataiku.llm.evaluation.utils.ragas.ragas_utils_0_2_12 import check_use_ragas_metrics, has_ragas_llm_metrics, RagasMetricsComputer, \
        has_context_based_metrics, create_empty_ragas_llm_metrics
else:
    raise ImportError(f"Version of ragas {ragas_version} is not supported (only 0.1.X and 0.2.X).")

logger = logging.getLogger(__name__)


class LLMEvaluationRecipe(GenAIEvaluationRecipe):
    """
    LLM evaluation recipe: validates the input dataset, prepares the metric inputs according to the
    recipe input format and computes the configured metrics (RAGAS, BERT Score, BLEU, ROUGE, token count
    and user-defined custom metrics).
    """

    def __init__(
        self,
        recipe_desc: GenAIEvalRecipeDesc,
        run_folder: FolderContext,
        model_evaluation_folder: Optional[FolderContext],
        input_dataset_smartname: str,
        output_dataset_smartname: str,
        metrics_dataset_smartname: str,
        ragas_max_workers: int
    ) -> None:
        """
        Forwards all arguments, unchanged and in the same order, to the parent recipe constructor.
        """
        super().__init__(
            recipe_desc,
            run_folder,
            model_evaluation_folder,
            input_dataset_smartname,
            output_dataset_smartname,
            metrics_dataset_smartname,
            ragas_max_workers,
        )

    @classmethod
    def build(
        cls,
        run_folder_path: str,
        model_evaluation_folder_path: str,
        input_dataset_smartname: str,
        output_dataset_smartname: str,
        metrics_dataset_smartname: str,
        ragas_max_workers: int
    ) -> "LLMEvaluationRecipe":
        """
        Creates an ``LLMEvaluationRecipe`` from folder paths, reading the recipe description from ``desc.json`` in the run folder.
        :param run_folder_path: Path of the run folder, which must contain ``desc.json``.
        :param model_evaluation_folder_path: Path of the model evaluation folder; when empty, no model evaluation folder is used.
        :param input_dataset_smartname: Name of the input dataset.
        :param output_dataset_smartname: Name of the output dataset.
        :param metrics_dataset_smartname: Name of the metrics dataset.
        :param ragas_max_workers: Value forwarded to the constructor as ``ragas_max_workers``.
        :return: The configured ``LLMEvaluationRecipe`` instance.
        """
        run_folder: FolderContext = build_folder_context(run_folder_path)

        # The model evaluation folder is optional: a falsy path means "none".
        model_evaluation_folder: Optional[FolderContext] = None
        if model_evaluation_folder_path:
            model_evaluation_folder = build_folder_context(model_evaluation_folder_path)

        raw_desc = run_folder.read_json("desc.json")
        return cls(
            GenAIEvalRecipeDesc(raw_desc),
            run_folder,
            model_evaluation_folder,
            input_dataset_smartname,
            output_dataset_smartname,
            metrics_dataset_smartname,
            ragas_max_workers,
        )

    @classmethod
    def build_for_test(cls, recipe_desc: GenAIEvalRecipeDesc, input_dataset_smartname: str) -> "LLMEvaluationRecipe":
        """
        Builds the ``LLMEvaluationRecipe`` used by ``test_custom_metric`` to try out user-defined custom metrics.
        :param recipe_desc: Recipe description containing the evaluation parameters.
        :param input_dataset_smartname: Name of the input dataset to use for testing.
        :return: An ``LLMEvaluationRecipe`` instance configured for custom metric testing, without writing to datasets or model evaluation stores.
        """
        # FIXME: None is passed for parameters that are not annotated as optional, which triggers typing warnings.
        return cls(
            recipe_desc,
            None,  # run_folder
            None,  # model_evaluation_folder
            input_dataset_smartname,
            None,  # output_dataset_smartname
            None,  # metrics_dataset_smartname
            0,  # ragas_max_workers
        )

    def test_custom_metric(self, custom_metric: GenAiCustomEvaluationMetric) -> Dict[str, Any]:
        """
        Runs a user-defined ``custom_metric`` on the first 5 lines of the input dataset, so that users
        can check that their implementation behaves as expected.
        The metric is computed with ``fail_on_errors`` disabled.
        :param custom_metric: The custom metric to be tested.
        :return: A dictionary containing the test results of ``custom_metric`` on the 5 first lines, looks like a ``RowByRowCustomMetricSuccess`` (value, rowByRowValues, didSucceed, metric).
        """
        full_df = self._get_input_df()
        self._run_sanity_checks(full_df)

        # Only a small sample is evaluated.
        head_df = full_df[0:5]
        custom_metrics_params = custom_metrics_utils.RecipeParamsForCustomMetrics(self.recipe_desc, self.recipe_desc.completion_settings)
        head_metric_inputs = self._craft_metric_inputs(head_df)
        logger.info(f"Custom metric: {custom_metric.name} will be tested")
        return custom_metrics_utils.compute_custom_metric(custom_metric, head_df, head_metric_inputs, False, custom_metrics_params)

    def _get_perf_file_name(self) -> str:
        """
        :return: Name of the performance file for this recipe.
        """
        return "llm_perf.json"

    def _run_sanity_checks(self, input_df: pd.DataFrame) -> None:
        """
        Checks that every configured column exists in the input dataframe and that the dataframe has rows.
        :param input_df: The input dataframe to validate.
        :raises LLMEvalException: If a configured column is missing or if the dataframe is empty.
        """
        known_columns = input_df.columns

        def is_missing(column_name: Optional[str]) -> bool:
            # A column that is not configured (None) is never considered missing.
            return column_name is not None and column_name not in known_columns

        if is_missing(self.input_column_name):
            raise LLMEvalException(f"Input column {self.input_column_name} is not a column of {self.input_dataset_smartname}.")
        if is_missing(self.output_column_name):
            raise LLMEvalException(f"Output column {self.output_column_name} is not a column of {self.input_dataset_smartname}.")
        if is_missing(self.ground_truth_column_name):
            raise LLMEvalException(f"Ground truth column {self.ground_truth_column_name} is not a column of {self.input_dataset_smartname}.")
        if is_missing(self.context_column_name):
            raise LLMEvalException(f"Context column {self.context_column_name} is not a column of {self.input_dataset_smartname}.")

        if not len(input_df):
            raise LLMEvalException("The evaluation dataset is empty. Please check your recipe configuration (dataset, sampling parameters, etc.)")
        logger.info(f"Input dataframe of shape {input_df.shape}")

    def _craft_metric_inputs(self, input_df: pd.DataFrame) -> GenAIMetricInput:
        """
        Builds the metric inputs from the input dataframe and, for the ``PROMPT_RECIPE`` and
        ``DATAIKU_ANSWERS`` input formats, overrides some of them with parsed values that are
        also stored in the output dataframe.
        :param input_df: The input dataframe.
        :return: The metric inputs to feed to the metric computations.
        """
        metric_inputs = GenAIMetricInput.from_df(input_df, self.recipe_desc)
        input_format = self.recipe_desc.input_format

        if input_format == "DATAIKU_ANSWERS":
            needs_context = has_context_based_metrics(self.recipe_desc.metrics)
            metric_inputs.context = dataiku_answers_utils.try_get_parsed_dataiku_answer_context(input_df, needs_context)
            self.output_df[PARSED_CONTEXT_NAME] = metric_inputs.context
            return metric_inputs

        if input_format != "PROMPT_RECIPE":
            # Other formats: the inputs built from the dataframe are used as they are.
            return metric_inputs

        metric_inputs.input = prompt_recipe_utils.try_get_reconstructed_prompt_recipe_input(input_df)
        self.output_df[RECONSTRUCTED_INPUT_NAME] = metric_inputs.input

        if not prompt_recipe_utils.has_raw_response(input_df):
            # No raw response to parse: fall back to empty series.
            metric_inputs.output = pd.Series(dtype=object)
            metric_inputs.context = pd.Series(dtype=object)
        else:
            metric_inputs.output = prompt_recipe_utils.try_get_parsed_prompt_recipe_output(input_df)
            needs_context = has_context_based_metrics(self.recipe_desc.metrics)
            metric_inputs.context = prompt_recipe_utils.try_get_parsed_prompt_recipe_context(input_df, needs_context)

        # The parsed context is stored JSON-encoded in the output dataframe.
        self.output_df[PARSED_CONTEXT_NAME] = metric_inputs.context.apply(json.dumps)
        self.output_df[PARSED_OUTPUT_NAME] = metric_inputs.output
        return metric_inputs

    def _compute_and_update_metrics(self, input_df: pd.DataFrame, metric_inputs: GenAIMetricInput) -> None:
        """
        Computes each enabled family of metrics in turn (RAGAS, BERT Score, BLEU, ROUGE, token count, then
        custom metrics) and records the results through ``_update_outputs`` / ``_update_outputs_custom_metric``.
        When a built-in family raises an ``LLMEvalException``, ``failure_utils.raise_or_continue`` is called
        and, if it returns, empty results are recorded for that family.
        :param input_df: The input dataframe.
        :param metric_inputs: The metric inputs crafted from ``input_df``.
        """
        desc = self.recipe_desc
        fail_on_errors = desc.fail_on_errors

        if has_ragas_llm_metrics(desc.metrics):
            logger.info("Some metrics using RAGAS will be computed")
            try:
                check_use_ragas_metrics(
                    desc.metrics,
                    self.has_ground_truth,
                    self.has_context,
                    False,
                    desc.completion_llm_id,
                    desc.embedding_llm_id,
                    desc.input_format == "PROMPT_RECIPE",
                )
                computer = RagasMetricsComputer(
                    desc.completion_llm_id,
                    desc.completion_settings,
                    desc.embedding_llm_id,
                    self.ragas_max_workers,
                    False,
                    can_compute_multimodal_metrics=desc.input_format == "PROMPT_RECIPE",
                )
                perf_part, values_df = computer.compute_llm_metrics(metric_inputs, desc.metrics)
            except LLMEvalException as err:
                failure_utils.raise_or_continue(err, "some RAGAS", fail_on_errors, desc.input_format)
                perf_part, values_df = create_empty_ragas_llm_metrics(metric_inputs, desc.metrics)
            self._update_outputs(perf_part, values_df)

        if bert_score_utils.has_bert_score(desc.metrics):
            logger.info("BERT Score metrics will be computed")
            try:
                perf_part, values_df = bert_score_utils.compute_bert_score(metric_inputs, desc.bert_score_model_type)
            except LLMEvalException as err:
                failure_utils.raise_or_continue(err, "BERT Score", fail_on_errors, desc.input_format)
                perf_part, values_df = bert_score_utils.create_empty_bert_score(metric_inputs)
            self._update_outputs(perf_part, values_df)

        if bleu_utils.has_bleu(desc.metrics):
            logger.info("BLEU metrics will be computed")
            try:
                perf_part, values_df = bleu_utils.compute_bleu(metric_inputs, desc.bleu_tokenizer)
            except LLMEvalException as err:
                failure_utils.raise_or_continue(err, "BLEU", fail_on_errors, desc.input_format)
                perf_part, values_df = bleu_utils.create_empty_bleu(metric_inputs)
            self._update_outputs(perf_part, values_df)

        if rouge_utils.has_rouge(desc.metrics):
            logger.info("ROUGE metrics will be computed")
            try:
                perf_part, values_df = rouge_utils.compute_rouge(metric_inputs)
            except LLMEvalException as err:
                failure_utils.raise_or_continue(err, "ROUGE", fail_on_errors, desc.input_format)
                perf_part, values_df = rouge_utils.create_empty_rouge(metric_inputs)
            self._update_outputs(perf_part, values_df)

        if not token_count_utils.can_token_count(desc):
            logger.info("Token count metrics won't be computed")
        else:
            logger.info("Token count metrics will be computed")
            try:
                perf_part, values_df = token_count_utils.compute_token_count(input_df, desc.input_format)
            except LLMEvalException as err:
                failure_utils.raise_or_continue(err, "Token Count", fail_on_errors, desc.input_format)
                perf_part, values_df = token_count_utils.create_empty_token_count(metric_inputs)
            self._update_outputs(perf_part, values_df)

        custom_metrics_params = custom_metrics_utils.RecipeParamsForCustomMetrics(desc, desc.completion_settings)
        for custom_metric in self.custom_metrics:
            logger.info(f"Custom metric: {custom_metric.name} will be computed")
            # No try/except here: custom metric errors are handled inside compute_custom_metric so that they end up in the returned result.
            custom_result = custom_metrics_utils.compute_custom_metric(custom_metric, input_df, metric_inputs, fail_on_errors, custom_metrics_params)
            self._update_outputs_custom_metric(custom_metric, custom_result)

    def _update_outputs_custom_metric(self, custom_metric: GenAiCustomEvaluationMetric, custom_metrics_result: Dict[str, Any]) -> None:
        """
        Records the result of one custom metric in the output DataFrame, in the performance metrics
        and in the metrics DataFrame, depending on which outputs are enabled.
        :param custom_metric: Custom metric definition.
        :param custom_metrics_result: Result dictionary looking like a ``RowByRowCustomMetricSuccess`` (value, rowByRowValues, didSucceed, metric).
        """
        metric_name = custom_metric.name

        if self.has_output_dataset or self.has_model_evaluation_store:
            self.output_df[metric_name] = custom_metrics_result.get("rowByRowValues")

        if self.has_model_evaluation_store:
            # Create the list of custom metric results on first use, then append to it.
            self.perf["metrics"].setdefault("customMetricsResults", []).append(custom_metrics_result)

        if self.has_metrics_dataset:
            self.metrics_df[metric_name] = [custom_metrics_result.get("value")]  # TODO : what about append ?


if __name__ == "__main__":
    # Process-level setup: debugging handler, logging configuration, then the DKU environment.
    debugging.install_handler()
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s",
    )
    read_dku_env_and_set()

    # Positional command-line arguments, in order:
    # run folder, model evaluation folder, input dataset, output dataset, metrics dataset, ragas max workers.
    cli_args = sys.argv
    run_folder_path = cli_args[1]
    model_evaluation_folder_path = cli_args[2]
    input_dataset_smartname = cli_args[3]
    output_dataset_smartname = cli_args[4]
    metrics_dataset_smartname = cli_args[5]
    ragas_max_workers = int(cli_args[6])

    # Building and running the recipe both happen inside the error monitoring wrapper.
    with ErrorMonitoringWrapper():
        LLMEvaluationRecipe.build(
            run_folder_path=run_folder_path,
            model_evaluation_folder_path=model_evaluation_folder_path,
            input_dataset_smartname=input_dataset_smartname,
            output_dataset_smartname=output_dataset_smartname,
            metrics_dataset_smartname=metrics_dataset_smartname,
            ragas_max_workers=ragas_max_workers,
        ).run()
