import ast
import json
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import dataiku
import pandas as pd

from dataiku import Dataset
from dataiku.base.folder_context import FolderContext
from dataiku.core import schema_handling
from dataiku.doctor import step_constants, utils
from dataiku.doctor.diagnostics import default_diagnostics
from dataiku.doctor.evaluation.base import sample_and_store_dataframe
from dataiku.doctor.utils.listener import DiagOnlyContext, ProgressListener
from dataiku.llm.evaluation.exceptions import GenAIEvalException, ToolCallException
from dataiku.llm.evaluation.genai_metrics_input import GenAIMetricInput
from dataiku.llm.evaluation.genai_eval_recipe_desc import GenAIEvalRecipeDesc
from dataiku.llm.evaluation.utils import dataiku_answers_utils
from dataiku.llm.evaluation.utils.common import PROMPT_RECIPE_RAW_QUERY_NAME, PROMPT_RECIPE_RAW_RESPONSE_NAME


# Column receiving the output extracted from a prompt recipe "raw response" JSON.
PARSED_OUTPUT_NAME = "dkuParsedOutput"  # Keep in sync with GenAIEvaluationUtils.java
# Column receiving the contexts (as an array of strings) extracted from a prompt recipe or answer plugin
# "raw response" JSON.
PARSED_CONTEXT_NAME = "dkuParsedContexts"  # Keep in sync with GenAIEvaluationUtils.java
# Column holding the "full" input from the prompt recipe (system prompt, examples, etc.).
RECONSTRUCTED_INPUT_NAME = "dkuReconstructedInput"  # Keep in sync with GenAIEvaluationUtils.java
# Column receiving the list of tool calls (as an array of strings) extracted from a prompt recipe or answer plugin
# "raw response" JSON.
PARSED_TOOL_CALLS_NAME = "dkuParsedToolCalls"  # Keep in sync with GenAIEvaluationUtils.java
# Column receiving the trajectory extracted from a prompt recipe "raw response" JSON.
PARSED_TRAJECTORY_NAME = "dkuParsedTrajectory"  # Keep in sync with GenAIEvaluationUtils.java

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class GenAIEvaluationRecipe(ABC):
    """
    Base class for GenAI evaluation recipes.

    It stores the recipe settings, loads the input dataset, delegates the sanity checks and the metric computation
    to subclasses through abstract hooks, and writes the results to whichever outputs are configured (model
    evaluation folder, output dataset, metrics dataset).
    """

    def __init__(
        self,
        recipe_desc: GenAIEvalRecipeDesc,
        run_folder: FolderContext,
        model_evaluation_folder: Optional[FolderContext],
        input_dataset_smartname: str,
        output_dataset_smartname: str,
        metrics_dataset_smartname: str,
        ragas_max_workers: int,
    ) -> None:
        """
        :param recipe_desc: Recipe settings (input format, column names, metrics, sampling, custom metrics/traits).
        :param run_folder: Folder used as context for the diagnostics/progress listener during ``run``.
        :param model_evaluation_folder: Folder receiving the performance file and the sample; ``None`` disables it.
        :param input_dataset_smartname: Name of the dataset to evaluate.
        :param output_dataset_smartname: Name of the row-by-row output dataset; ``None`` or empty disables it.
        :param metrics_dataset_smartname: Name of the global metrics dataset; ``None`` or empty disables it.
        :param ragas_max_workers: Stored on the instance as-is; not used by this base class.
        :raises GenAIEvalException: If, for the "PROMPT_RECIPE" or "DATAIKU_ANSWERS" input formats, a column name is
            set on ``recipe_desc`` to something other than the value expected for that format.
        """
        self.run_folder = run_folder
        self.model_evaluation_folder = model_evaluation_folder

        self.input_dataset_smartname = input_dataset_smartname
        self.output_dataset_smartname = output_dataset_smartname
        self.metrics_dataset_smartname = metrics_dataset_smartname

        self.recipe_desc = recipe_desc

        # Define the default input, output, and context columns of recipe_desc.
        # If you create the recipe with the public API, they are not automatically defined so we infer them.
        input_format = self.recipe_desc.input_format
        if input_format == "PROMPT_RECIPE":
            # Output, context and actual tool calls are all read from the raw response column.
            expected_columns = {
                "input_column_name": PROMPT_RECIPE_RAW_QUERY_NAME,
                "output_column_name": PROMPT_RECIPE_RAW_RESPONSE_NAME,
                "context_column_name": PROMPT_RECIPE_RAW_RESPONSE_NAME,
                "actual_tool_calls_column_name": PROMPT_RECIPE_RAW_RESPONSE_NAME,
            }
        elif input_format == "DATAIKU_ANSWERS":
            expected_columns = {
                "input_column_name": dataiku_answers_utils.DATAIKU_ANSWERS_QUESTION_NAME,
                "output_column_name": dataiku_answers_utils.DATAIKU_ANSWERS_ANSWER_NAME,
                "context_column_name": dataiku_answers_utils.DATAIKU_ANSWERS_SOURCES_NAME,
            }
        else:
            # Any other input format: the column names are taken from recipe_desc untouched.
            expected_columns = {}

        for column_name, expected_value in expected_columns.items():
            actual_value = getattr(recipe_desc, column_name, None)
            if actual_value is not None and actual_value != expected_value:
                raise GenAIEvalException(f"{column_name} should be set to '{expected_value}' instead of '{actual_value}'.")
            setattr(recipe_desc, column_name, expected_value)

        self.input_column_name = recipe_desc.input_column_name
        self.output_column_name = recipe_desc.output_column_name
        self.ground_truth_column_name = recipe_desc.ground_truth_column_name
        self.context_column_name = recipe_desc.context_column_name
        self.actual_tool_calls_column_name = recipe_desc.actual_tool_calls_column_name
        self.reference_tool_calls_column_name = recipe_desc.reference_tool_calls_column_name
        self.custom_metrics = recipe_desc.custom_metrics
        self.custom_traits = recipe_desc.custom_traits

        self.has_ground_truth = recipe_desc.ground_truth_column_name is not None
        self.has_context = recipe_desc.context_column_name is not None
        self.has_actual_tool_calls = recipe_desc.actual_tool_calls_column_name is not None
        self.has_reference_tool_calls = recipe_desc.reference_tool_calls_column_name is not None
        self.has_model_evaluation_store = self.model_evaluation_folder is not None
        # An output is enabled only when its smartname is neither None nor empty.
        self.has_output_dataset = bool(self.output_dataset_smartname)
        self.has_metrics_dataset = bool(self.metrics_dataset_smartname)
        self.ragas_max_workers = ragas_max_workers

        # Row-by-row results; filled in by run() and _update_outputs().
        self.output_df = pd.DataFrame()
        # Single-row DataFrame of global metrics, starting with the evaluation date.
        self.metrics_df = pd.DataFrame.from_dict({'date': [utils.get_datetime_now_utc()]})

        # Content of the performance file written to the model evaluation folder.
        self.perf = {'metrics': {}}

    def run(self) -> None:
        """
        Runs the evaluation: loads the input, runs the sanity checks, crafts the metric inputs, computes the metrics
        and writes every configured output (model evaluation folder, output dataset, metrics dataset).
        """
        default_diagnostics.register_evaluation_callbacks()
        listener = ProgressListener(context=DiagOnlyContext(self.run_folder))

        input_df = self._get_input_df()
        row_count = len(input_df)
        self.metrics_df["sampleRowCount"] = row_count
        self.perf["metrics"]["sampleRowCount"] = row_count

        self._run_sanity_checks(input_df)

        if self.has_output_dataset or self.has_model_evaluation_store:
            # Work on a copy so that metric columns appended later do not end up in input_df.
            self.output_df = input_df.copy()

        # NOTE(review): both steps below are pushed with STEP_EVAL_PROCESSING - confirm this is intentional.
        with listener.push_step(step_constants.ProcessingStep.STEP_EVAL_PROCESSING):
            metric_inputs = self._craft_metric_inputs(input_df)

        with listener.push_step(step_constants.ProcessingStep.STEP_EVAL_PROCESSING):
            self._compute_and_update_metrics(input_df, metric_inputs)

        if self.has_model_evaluation_store:
            logger.info("Writing performance in ME folder")
            self.model_evaluation_folder.write_json(self._get_perf_file_name(), self.perf)
            evaluation_file = "_evaluation.json"
            evaluation = self.model_evaluation_folder.read_json(evaluation_file)
            # Read the row count again here (rather than reusing row_count) in case a hook changed input_df in place.
            evaluation["nbEvaluationRows"] = input_df.shape[0]
            self.model_evaluation_folder.write_json(evaluation_file, evaluation)
            output_df_schema = {"columns": schema_handling.get_schema_from_df(self.output_df)}
            sample_and_store_dataframe(
                self.model_evaluation_folder,
                self.output_df,
                output_df_schema,
                filename="sample_with_metrics.csv.gz",
                schema_filename="sample_with_metrics_schema.json",
                limit_sampling=True
            )

        if self.has_output_dataset:
            output_dataset = Dataset(self.output_dataset_smartname)
            logger.info("Writing output dataset")
            output_dataset.write_dataframe(self.output_df)

        if self.has_metrics_dataset:
            metrics_dataset = Dataset(self.metrics_dataset_smartname)
            logger.info("Writing metrics data")
            metrics_dataset.write_dataframe(self.metrics_df)

    @staticmethod
    def _load_json_array_cell(cell: Any, default_value: List[Any]) -> Any:
        """
        Parses a single cell as JSON.
        :param cell: Raw cell value.
        :param default_value: Value to use for nulls and blank strings, which are both considered empty.
        :return: The decoded JSON value, or a fresh copy of ``default_value`` when the cell is empty.
        """
        if pd.isna(cell) or (isinstance(cell, str) and not cell.strip()):
            # Return a copy: handing out the same list object to every empty cell would let an in-place mutation of
            # one row's value leak into all the other empty rows.
            return list(default_value)

        return json.loads(cell)

    @staticmethod
    def _parse_json_array_column(df: pd.DataFrame, column_name: str, default_value: List[Any]) -> List[List[Any]]:
        """
        Parses every cell of a column as a JSON array.
        :param df: DataFrame containing the column.
        :param column_name: Name of the column to parse.
        :param default_value: Value used (copied) for null or blank cells.
        :return: One list per row, in row order.
        :raises GenAIEvalException: If a cell is not valid JSON, or does not decode to a list.
        """
        parsed_cells = []
        for index, value in df[column_name].items():
            not_a_list_message = f"Column '{column_name}' at index [{index}] was parsed as JSON but the result is not a list. Cell value is: '{value}'"
            try:
                parsed_value = GenAIEvaluationRecipe._load_json_array_cell(value, default_value)
            except json.JSONDecodeError as e:
                raise GenAIEvalException(f"Error parsing as JSON '{column_name}' at index [{index}]: '{str(e)}'. Cell value is: '{value}'") from e
            except TypeError as e:
                # json.loads raises TypeError on cells that are not str/bytes; reported like a non-list result.
                raise GenAIEvalException(not_a_list_message) from e

            if not isinstance(parsed_value, list):
                raise GenAIEvalException(not_a_list_message)
            parsed_cells.append(parsed_value)

        return parsed_cells

    def _get_input_df(self) -> pd.DataFrame:
        """
        Loads the input dataset (with the recipe sampling) and decodes its list-valued columns.

        Columns typed "array" in the dataset schema that are used as ground truth, context or tool calls are parsed
        as JSON arrays. Tool call columns that are not typed "array" are parsed as Python literals.
        :return: The input DataFrame with the relevant columns turned into Python lists.
        :raises GenAIEvalException: If a cell of an "array" column is not valid JSON or is not a JSON array.
        :raises ToolCallException: If a non-"array" tool calls column cannot be parsed into lists.
        """
        input_dataset = dataiku.Dataset(self.input_dataset_smartname)
        input_schema = input_dataset.read_schema()
        input_df = input_dataset.get_dataframe(sampling=self.recipe_desc.sampling, infer_with_pandas=False)

        json_array_columns = (
            self.ground_truth_column_name,
            self.context_column_name,
            self.reference_tool_calls_column_name,
            self.actual_tool_calls_column_name,
        )
        if self.recipe_desc.input_format == "PROMPT_RECIPE":
            # For prompt recipes the actual tool calls column is the raw response column (see __init__), so it is
            # left out of the literal parsing below.
            tool_call_input_columns = [self.reference_tool_calls_column_name]
        else:
            tool_call_input_columns = [self.reference_tool_calls_column_name, self.actual_tool_calls_column_name]

        for column_schema in input_schema:
            column = column_schema["name"]
            is_array_column = column_schema["type"] == "array"
            if is_array_column and column in json_array_columns:
                logger.info("Loading column %s", column)
                input_df[column] = self._parse_json_array_column(input_df, column, [])
            elif not is_array_column and column in tool_call_input_columns:
                try:
                    input_df[column] = input_df[column].apply(ast.literal_eval)
                    # NOTE(review): only the first row is type-checked here; later rows that do not evaluate to a
                    # list are not detected at this point.
                    if len(input_df[column]) != 0 and not isinstance(input_df[column].iloc[0], list):
                        raise ToolCallException("Row did not parse into an array")
                except Exception as e:
                    raise ToolCallException(
                        f"The tool calls column '{column}' must be of type array of strings or array of dicts, got: {input_df[column].iloc[0] if len(input_df[column]) != 0 else 'EMPTY COLUMN'}. Error: {e}"
                    ) from e

        return input_df

    def _update_outputs(self, perf_metrics: Dict[str, Any], row_by_row_metrics_df: pd.DataFrame) -> None:
        """
        Updates the output DataFrame, the metrics DataFrame and the performance metrics kept for the model evaluation
        folder. Each one is only updated when the corresponding output is enabled.
        :param perf_metrics: Dictionary containing the global metrics, used for the metrics DataFrame and the performance of the model evaluation.
        :param row_by_row_metrics_df: DataFrame with the metrics computed for each row.
        """
        if self.has_output_dataset or self.has_model_evaluation_store:
            # Column-wise concatenation: the per-row metrics are appended as new columns.
            self.output_df = pd.concat([self.output_df, row_by_row_metrics_df], axis=1)

        if self.has_model_evaluation_store:
            self.perf["metrics"].update(perf_metrics)

        if self.has_metrics_dataset:
            self.metrics_df = pd.concat([self.metrics_df, pd.DataFrame.from_dict([perf_metrics])], axis=1)  # NOSONAR : Wrap perf_metrics in a list to create a single-row DataFrame; without the list, from_dict expects dict values to be array-like.

    def has_bert_score(self) -> bool:
        """
        :return: True if "bertScore" is among the metrics of the recipe.
        """
        # Other metric util files have their has_X method within the utils file itself, but we can't do that for
        # bert-score. The bert-score dependency (and the torch secondary dependency it brings) are optional for the
        # agent evaluation recipe (you only need them to compute the BERTScore metric), so we only want to load them if
        # needed. If this function was in the bert_score_utils file, we could cause import errors trying to use it in
        # code envs without bert-score installed.
        return "bertScore" in self.recipe_desc.metrics

    @abstractmethod
    def _get_perf_file_name(self) -> str:
        """
        :return: Name of the performance file written by ``run`` in the model evaluation folder.
        """
        pass

    @abstractmethod
    def _run_sanity_checks(self, input_df: pd.DataFrame) -> None:
        """
        Hook called by ``run`` right after the input is loaded, before any metric is computed.
        :param input_df: The loaded input DataFrame.
        """
        pass

    @abstractmethod
    def _craft_metric_inputs(self, input_df: pd.DataFrame) -> GenAIMetricInput:
        """
        Hook called by ``run`` to build the metric inputs from the input DataFrame.
        :param input_df: The loaded input DataFrame.
        :return: The inputs later handed to ``_compute_and_update_metrics``.
        """
        pass

    @abstractmethod
    def _compute_and_update_metrics(self, input_df: pd.DataFrame, metric_inputs: GenAIMetricInput) -> None:
        """
        Hook called by ``run`` to compute the metrics, before the outputs are written.
        :param input_df: The loaded input DataFrame.
        :param metric_inputs: The value returned by ``_craft_metric_inputs``.
        """
        pass
