import json
import logging
import sys
from typing import Any, Dict, Optional

import pandas as pd

from dataiku.base.folder_context import FolderContext, build_folder_context
from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.core import debugging
from dataiku.llm.evaluation.genai_custom_evaluation_metric  import GenAiCustomEvaluationMetric
from dataiku.llm.evaluation.exceptions import AgentEvalException
from dataiku.llm.evaluation.genai_evaluation_recipe import PARSED_OUTPUT_NAME, PARSED_TOOL_CALLS_NAME, PARSED_TRAJECTORY_NAME, RECONSTRUCTED_INPUT_NAME, GenAIEvaluationRecipe
from dataiku.llm.evaluation.genai_metrics_input import GenAIMetricInput
from dataiku.llm.evaluation.genai_eval_recipe_desc import GenAIEvalRecipeDesc
from dataiku.llm.evaluation.utils import custom_metrics_utils, execution_time_utils, prompt_recipe_utils, tool_call_utils, tool_statistics_utils, bert_score_utils, failure_utils

from importlib.metadata import PackageNotFoundError, version

# Resolve the installed ragas version at import time and load the matching compatibility module.
# ``importlib.metadata.version`` raises ``PackageNotFoundError`` (it does not return an empty value)
# when the distribution is not installed, so that case is translated explicitly into the intended,
# user-readable ImportError (the original metadata error is kept as the cause).
try:
    ragas_version = version("ragas")
except PackageNotFoundError as err:
    raise ImportError("ragas package is missing.") from err
if not ragas_version:
    # Defensive: broken metadata could still yield an empty version string.
    raise ImportError("ragas package is missing.")
if ragas_version.startswith("0.1"):
    raise ImportError("Agent Evaluation needs ragas >= 0.2.12. Please update your Code Environment.")
elif ragas_version.startswith("0.2"):
    from dataiku.llm.evaluation.utils.ragas import ragas_utils_0_2_12
else:
    # Only the 0.2.X line is supported (0.1.X is rejected above).
    raise ImportError(f"Version of ragas {ragas_version} is not supported (only 0.2.X, with 0.2.X >= 0.2.12).")

logger = logging.getLogger(__name__)


class AgentEvaluationRecipe(GenAIEvaluationRecipe):
    """
    GenAI evaluation recipe specialised for agents.

    Computes total execution time, tool statistics, tool call metrics, RAGAS agent/LLM metrics,
    BERT Score and user-defined custom metrics/traits, and pushes each result through the
    ``_update_outputs`` / ``_update_outputs_custom_measures`` methods.
    """

    def __init__(
        self,
        recipe_desc: GenAIEvalRecipeDesc,
        run_folder: FolderContext,
        model_evaluation_folder: Optional[FolderContext],
        input_dataset_smartname: str,
        output_dataset_smartname: str,
        metrics_dataset_smartname: str,
        ragas_max_workers: int
    ) -> None:
        super().__init__(
            recipe_desc,
            run_folder,
            model_evaluation_folder,
            input_dataset_smartname,
            output_dataset_smartname,
            metrics_dataset_smartname,
            ragas_max_workers
        )

    @classmethod
    def build(
        cls,
        run_folder_path: str,
        model_evaluation_folder_path: str,
        input_dataset_smartname: str,
        output_dataset_smartname: str,
        metrics_dataset_smartname: str,
        ragas_max_workers: int
    ) -> "AgentEvaluationRecipe":
        """
        Creates a recipe from folder paths, reading the recipe description from ``desc.json`` in the run folder.
        :param run_folder_path: Path of the run folder (must contain ``desc.json``).
        :param model_evaluation_folder_path: Path of the model evaluation folder; a falsy value means no model evaluation store.
        :param input_dataset_smartname: Name of the dataset to evaluate.
        :param output_dataset_smartname: Name of the output dataset.
        :param metrics_dataset_smartname: Name of the metrics dataset.
        :param ragas_max_workers: Maximum number of RAGAS workers.
        :return: A configured ``AgentEvaluationRecipe``.
        """
        run_folder: FolderContext = build_folder_context(run_folder_path)
        model_evaluation_folder: Optional[FolderContext] = build_folder_context(model_evaluation_folder_path) if model_evaluation_folder_path else None
        recipe_desc = GenAIEvalRecipeDesc(run_folder.read_json("desc.json"))
        return cls(
            recipe_desc,
            run_folder,
            model_evaluation_folder,
            input_dataset_smartname,
            output_dataset_smartname,
            metrics_dataset_smartname,
            ragas_max_workers
        )

    @classmethod
    def build_for_test(cls, recipe_desc: GenAIEvalRecipeDesc, input_dataset_smartname: str) -> "AgentEvaluationRecipe":
        """
        Creates an ``AgentEvaluationRecipe`` instance for testing user-defined custom metrics through ``test_custom_metric``.
        :param recipe_desc: Recipe description containing the evaluation parameters.
        :param input_dataset_smartname: Name of the input dataset to use for testing.
        :return: An ``AgentEvaluationRecipe`` instance configured for custom metric testing, without writing to datasets or model evaluation stores.
        """
        # FIXME: typing warning here due to passing None for parameters that are not declared Optional
        return cls(recipe_desc, None, None, input_dataset_smartname, None, None, 0)

    def test_custom_metric(self, custom_metric: GenAiCustomEvaluationMetric) -> Dict[str, Any]:
        """
        Tests a user-defined ``custom_metric`` on the first 5 lines of the input dataset.
        Allows users to validate the correctness of their custom metric implementation.
        Does not write any results to model evaluation stores or datasets.
        :param custom_metric: The custom metric to be tested.
        :return: A dictionary containing the test results of ``custom_metric`` on the 5 first lines, looks like a ``RowByRowCustomMetricSuccess`` (value, rowByRowValues, didSucceed, metric).
        """
        input_df = self._get_input_df()
        self._run_sanity_checks(input_df)
        sample_df = input_df.head(5)
        recipe_params_for_custom_metrics = custom_metrics_utils.RecipeParamsForCustomMetrics(self.recipe_desc, self.recipe_desc.completion_settings)
        metric_inputs = self._craft_metric_inputs(sample_df)
        logger.info(f"Custom metric: {custom_metric.name} will be tested")
        # fail_on_errors=False: errors are reported inside the returned result instead of being raised.
        return custom_metrics_utils.compute_custom_metric(custom_metric, sample_df, metric_inputs, False, recipe_params_for_custom_metrics)

    def _get_perf_file_name(self) -> str:
        """Returns the file name used for this recipe's performance metrics."""
        return "agent_perf.json"

    def _run_sanity_checks(self, input_df: pd.DataFrame) -> None:
        """
        Validates the input dataframe against the recipe configuration.
        :param input_df: The dataframe to evaluate.
        :raises AgentEvalException: If a configured column is missing from ``input_df`` or if ``input_df`` is empty.
        """
        # (label used in the error message, configured column name or None when not configured)
        configured_columns = [
            ("Input", self.input_column_name),
            ("Output", self.output_column_name),
            ("Ground truth", self.ground_truth_column_name),
            ("Actual tool calls", self.actual_tool_calls_column_name),
            ("Reference tool calls", self.reference_tool_calls_column_name),
        ]
        for label, column_name in configured_columns:
            if column_name is not None and column_name not in input_df.columns:
                raise AgentEvalException(f"{label} column {column_name} is not a column of {self.input_dataset_smartname}.")
        if len(input_df) == 0:
            raise AgentEvalException("The evaluation dataset is empty. Please check your recipe configuration (dataset, sampling parameters, etc.)")
        logger.info(f"Input dataframe of shape {input_df.shape}")

    def _craft_metric_inputs(self, input_df: pd.DataFrame) -> GenAIMetricInput:
        """
        Builds the metric inputs from ``input_df``.

        For ``PROMPT_RECIPE`` inputs, the input, output, tool calls and trajectory are
        reconstructed/parsed from the prompt recipe columns and also written to ``self.output_df``.
        :param input_df: The dataframe to evaluate.
        :return: The metric inputs used by all metric computations.
        """
        metric_inputs = GenAIMetricInput.from_df(input_df, self.recipe_desc)
        if self.recipe_desc.input_format == "PROMPT_RECIPE":
            metric_inputs.input = prompt_recipe_utils.try_get_reconstructed_prompt_recipe_input(input_df)
            # Falls back to an empty Series when the reference tool calls column is not configured or not present.
            metric_inputs.reference_tool_calls = input_df.get(self.reference_tool_calls_column_name, pd.Series(dtype=object))
            self.output_df[RECONSTRUCTED_INPUT_NAME] = metric_inputs.input

            if prompt_recipe_utils.has_raw_response(input_df):
                metric_inputs.output = prompt_recipe_utils.try_get_parsed_prompt_recipe_output(input_df)
                metric_inputs.actual_tool_calls = prompt_recipe_utils.try_get_parsed_prompt_recipe_tool_calls(input_df, True)
            else:
                metric_inputs.output = pd.Series(dtype=object)
                metric_inputs.actual_tool_calls = pd.Series(dtype=object)

            self.output_df[PARSED_OUTPUT_NAME] = metric_inputs.output
            self.output_df[PARSED_TOOL_CALLS_NAME] = metric_inputs.actual_tool_calls.apply(json.dumps)
            self.output_df[PARSED_TRAJECTORY_NAME] = prompt_recipe_utils.try_get_parsed_prompt_recipe_trajectory(input_df).apply(json.dumps)

        return metric_inputs

    def _compute_and_update_metrics(self, input_df: pd.DataFrame, metric_inputs: GenAIMetricInput) -> None:
        """
        Computes every metric enabled in the recipe description and pushes the results to the outputs.

        Failures (``AgentEvalException``) of built-in metrics are passed to ``failure_utils.raise_or_continue``;
        when that call returns instead of raising, empty placeholder results (``create_empty_*``) are used.
        Total execution time is best-effort: its failures are only logged. Custom metrics/traits handle
        their own errors.
        :param input_df: The dataframe to evaluate.
        :param metric_inputs: Inputs built by ``_craft_metric_inputs``.
        """
        if execution_time_utils.can_compute_total_execution_time(self.recipe_desc):
            logger.info("Total execution time will be computed")
            try:
                p95_total_execution_time, total_execution_time_df = execution_time_utils.compute_total_execution_time(input_df)
            except AgentEvalException as e:
                # Best-effort metric: never fails the recipe, but the cause is no longer silently dropped.
                logger.warning(f"Total execution time could not be computed, empty values will be used: {e}")
                p95_total_execution_time, total_execution_time_df = execution_time_utils.create_empty_p95_total_execution_time()
            self._update_outputs(p95_total_execution_time, total_execution_time_df)

        if tool_statistics_utils.can_tool_statistics(self.recipe_desc):
            logger.info("Tools execution statistics will be computed")
            try:
                tool_statistics, tool_statistics_df = tool_statistics_utils.compute_tool_statistics(input_df, metric_inputs)
            except AgentEvalException as e:
                failure_utils.raise_or_continue(e, "Tools execution statistics", self.recipe_desc.fail_on_errors, self.recipe_desc.input_format)
                tool_statistics, tool_statistics_df = tool_statistics_utils.create_empty_tool_statistics(metric_inputs)
            self._update_outputs(tool_statistics, tool_statistics_df)

        if tool_call_utils.has_tool_call_metrics(self.recipe_desc.metrics):
            logger.info("Recipe description includes tool call metrics, starting to compute them")
            try:
                tool_call_metrics, tool_call_metrics_df = tool_call_utils.compute_tool_call_metrics(metric_inputs, self.recipe_desc.metrics)
            except AgentEvalException as e:
                failure_utils.raise_or_continue(e, "some tool call metric", self.recipe_desc.fail_on_errors, self.recipe_desc.input_format)
                tool_call_metrics, tool_call_metrics_df = tool_call_utils.create_empty_tool_call_metrics(metric_inputs, self.recipe_desc.metrics)
            self._update_outputs(tool_call_metrics, tool_call_metrics_df)

        self._compute_and_update_ragas_metrics(metric_inputs)

        if bert_score_utils.has_bert_score(self.recipe_desc.metrics):
            logger.info("BERT Score metrics will be computed")
            try:
                bert_perf, bert_metrics_df = bert_score_utils.compute_bert_score(metric_inputs, self.recipe_desc.bert_score_model_type)
            except AgentEvalException as e:
                failure_utils.raise_or_continue(e, "BERT Score", self.recipe_desc.fail_on_errors, self.recipe_desc.input_format)
                bert_perf, bert_metrics_df = bert_score_utils.create_empty_bert_score(metric_inputs)
            self._update_outputs(bert_perf, bert_metrics_df)

        recipe_params_for_custom_metrics = custom_metrics_utils.RecipeParamsForCustomMetrics(self.recipe_desc, self.recipe_desc.completion_settings)
        for custom_metric in self.custom_metrics:
            logger.info(f"Custom metric: {custom_metric.name} will be computed")
            # Custom metrics errors are handled internally, because we want to return errors in custom_metric_result.
            result = custom_metrics_utils.compute_custom_metric(custom_metric, input_df, metric_inputs, self.recipe_desc.fail_on_errors, recipe_params_for_custom_metrics)
            self._update_outputs_custom_measures(custom_metric.name, result)

        for custom_trait in self.custom_traits:
            logger.info(f"Custom trait: {custom_trait.name} will be computed")
            # Custom traits errors are handled internally, because we want to return errors in the trait result.
            result = custom_metrics_utils.compute_custom_trait(custom_trait, input_df, metric_inputs, self.recipe_desc.fail_on_errors, recipe_params_for_custom_metrics)
            self._update_outputs_custom_measures(custom_trait.name, result)

    def _compute_and_update_ragas_metrics(self, metric_inputs: GenAIMetricInput) -> None:
        """
        Computes the RAGAS agent and/or LLM metrics selected in the recipe description and pushes them to the outputs.

        The two metric families are computed independently. If one family fails and
        ``failure_utils.raise_or_continue`` returns, only that family is replaced by empty placeholders,
        so the other family's computed results are kept. A failure while validating the configuration
        or building the RAGAS computer empties both families.
        :param metric_inputs: Inputs built by ``_craft_metric_inputs``.
        """
        metrics = self.recipe_desc.metrics
        compute_agent_metrics = ragas_utils_0_2_12.has_ragas_agent_metrics(metrics)
        compute_llm_metrics = ragas_utils_0_2_12.has_ragas_llm_metrics(metrics)
        if not (compute_agent_metrics or compute_llm_metrics):
            return

        logger.info("RAGAS metrics will be computed")
        ragas_metrics_computer = self._build_ragas_metrics_computer()

        # None means "not computed" (setup failed or computation raised and the recipe continues).
        agent_results = None
        if compute_agent_metrics and ragas_metrics_computer is not None:
            try:
                agent_results = ragas_metrics_computer.compute_agent_metrics(metric_inputs, metrics, self.has_ground_truth)
            except AgentEvalException as e:
                failure_utils.raise_or_continue(e, "some RAGAS", self.recipe_desc.fail_on_errors, self.recipe_desc.input_format)

        llm_results = None
        if compute_llm_metrics and ragas_metrics_computer is not None:
            try:
                llm_results = ragas_metrics_computer.compute_llm_metrics(metric_inputs, metrics)
            except AgentEvalException as e:
                failure_utils.raise_or_continue(e, "some RAGAS", self.recipe_desc.fail_on_errors, self.recipe_desc.input_format)

        if compute_agent_metrics:
            if agent_results is None:
                agent_results = ragas_utils_0_2_12.create_empty_ragas_agent_metrics(metric_inputs, metrics)
            agent_goal_accuracy, agent_goal_accuracy_df = agent_results
            self._update_outputs(agent_goal_accuracy, agent_goal_accuracy_df)
        if compute_llm_metrics:
            if llm_results is None:
                llm_results = ragas_utils_0_2_12.create_empty_ragas_llm_metrics(metric_inputs, metrics)
            ragas_perf, ragas_metric_df = llm_results
            self._update_outputs(ragas_perf, ragas_metric_df)

    def _build_ragas_metrics_computer(self) -> Optional["ragas_utils_0_2_12.RagasMetricsComputer"]:
        """
        Validates the RAGAS configuration and builds the RAGAS metrics computer.
        :return: The computer, or ``None`` when validation/construction raised an ``AgentEvalException``
            and ``failure_utils.raise_or_continue`` returned instead of raising.
        """
        is_prompt_recipe = self.recipe_desc.input_format == "PROMPT_RECIPE"
        try:
            ragas_utils_0_2_12.check_use_ragas_metrics(
                self.recipe_desc.metrics,
                self.has_ground_truth,
                self.has_context,
                self.has_actual_tool_calls,
                self.recipe_desc.completion_llm_id,
                self.recipe_desc.embedding_llm_id,
                is_prompt_recipe
            )
            return ragas_utils_0_2_12.RagasMetricsComputer(
                self.recipe_desc.completion_llm_id,
                self.recipe_desc.completion_settings,
                self.recipe_desc.embedding_llm_id,
                self.ragas_max_workers,
                False,
                can_compute_multimodal_metrics=is_prompt_recipe
            )
        except AgentEvalException as e:
            failure_utils.raise_or_continue(e, "some RAGAS", self.recipe_desc.fail_on_errors, self.recipe_desc.input_format)
            return None

    def _update_outputs_custom_measures(self, custom_measure_name: str, custom_metrics_result: Dict[str, Any]) -> None:
        """
        Updates the output DataFrame, the metrics DataFrame and the performance metrics file in the model evaluation folder
        with the result of one custom metric or custom trait.
        :param custom_measure_name: Name of the custom metric/trait, used as the column name in the output and metrics DataFrames.
        :param custom_metrics_result: Result dictionary looking like a ``RowByRowCustomMetricSuccess`` (value, rowByRowValues, didSucceed, metric).
        """
        if self.has_output_dataset or self.has_model_evaluation_store:
            # Prefer "rowByRowValues"; fall back to "rowByRowStringValues" when that key is absent.
            self.output_df[custom_measure_name] = custom_metrics_result.get("rowByRowValues", custom_metrics_result.get("rowByRowStringValues"))

        if self.has_model_evaluation_store:
            self.perf["metrics"].setdefault("customMetricsResults", []).append(custom_metrics_result)

        if self.has_metrics_dataset:
            self.metrics_df[custom_measure_name] = [custom_metrics_result.get("value", None)]  # TODO: what about append ?


if __name__ == "__main__":
    # Script entry point. Positional arguments, in order:
    #   1: run folder path
    #   2: model evaluation folder path
    #   3: input dataset smartname
    #   4: output dataset smartname
    #   5: metrics dataset smartname
    #   6: maximum number of RAGAS workers (parsed as int)
    debugging.install_handler()
    logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s")
    read_dku_env_and_set()

    cli_args = sys.argv
    # Values are read in argument order, exactly once each, before the recipe is built.
    build_kwargs = {
        "run_folder_path": cli_args[1],
        "model_evaluation_folder_path": cli_args[2],
        "input_dataset_smartname": cli_args[3],
        "output_dataset_smartname": cli_args[4],
        "metrics_dataset_smartname": cli_args[5],
        "ragas_max_workers": int(cli_args[6]),
    }

    with ErrorMonitoringWrapper():
        AgentEvaluationRecipe.build(**build_kwargs).run()
