import base64
import logging
from urllib.error import HTTPError

from dataiku.llm.evaluation.utils.completion_trace_handler import CompletionTraceHandler
import numpy as np
import pandas as pd

from langchain_core.callbacks import BaseCallbackHandler
from ragas import evaluate, RunConfig
from ragas.metrics.base import Metric
from ragas.metrics import (AnswerRelevancy, Faithfulness, ContextRecall, ContextPrecision, AnswerCorrectness, AnswerSimilarity,
    MultiModalRelevance, MultiModalFaithfulness,
    AgentGoalAccuracyWithReference, AgentGoalAccuracyWithoutReference)
from typing import List, Optional, Tuple, Dict, Iterable, Union
from datasets import Dataset
from datasets.features.features import Sequence
from ragas.dataset_schema import  SingleTurnSample, MultiTurnSample, EvaluationDataset
from ragas.messages import HumanMessage,AIMessage,ToolMessage,ToolCall

from dataiku import Folder
from dataiku.langchain.dku_embeddings import DKUEmbeddings, TraceableDKUEmbeddings
from dataiku.llm.evaluation.exceptions import RagasException
from dataiku.llm.evaluation.genai_metrics_input import GenAIMetricInput, GenAIMetricInputRole
from dataiku.llm.evaluation.utils import failure_utils
from dataiku.llm.evaluation.utils.metrics_utils import get_llm_args
from dataiku.llm.evaluation.utils.ragas.ragas_compatible_llm import RagasCompatibleLLM
from dataiku.llm.tracing import SpanBuilder
from dataiku.llm.types import CompletionSettings

# Ragas LLM metrics, keyed by the metric identifiers callers pass in their `metrics` lists.
# The values are metric classes (factories), not instances: a fresh instance is required at each
# metric computation run to prevent race conditions. Lambdas returning instances would work too.
RAGAS_LLM_METRICS_GENERATOR_MAP = {
    "answerRelevancy": AnswerRelevancy,
    "answerCorrectness": AnswerCorrectness,
    "answerSimilarity": AnswerSimilarity,
    "faithfulness": Faithfulness,
    "contextRecall": ContextRecall,
    "contextPrecision": ContextPrecision,
    "multimodalRelevancy": MultiModalRelevance,
    "multimodalFaithfulness": MultiModalFaithfulness,
}

# Ragas agent metrics; they are instantiated directly in RagasMetricsComputer.compute_agent_metrics.
RAGAS_AGENT_METRICS = {
    "agentGoalAccuracyWithReference",
    "agentGoalAccuracyWithoutReference",
}

# Metric names as reported in ragas results -> metric identifiers used by this module.
# Two ragas names map to "answerSimilarity" (presumably depending on the ragas version - TODO confirm).
RAGAS_LLM_OUTPUT_METRICS_MAP = {
    "answer_relevancy": "answerRelevancy",
    "semantic_similarity": "answerSimilarity",
    "faithfulness": "faithfulness",
    "context_recall": "contextRecall",
    "context_precision": "contextPrecision",
    "answer_correctness": "answerCorrectness",
    "answer_similarity": "answerSimilarity",
    "relevance_rate": "multimodalRelevancy",
    "faithful_rate": "multimodalFaithfulness",
}

# Metrics grouped by the input columns they require.
RAGAS_METRICS_WITH_GROUND_TRUTH = {
    "contextPrecision",
    "contextRecall",
    "answerCorrectness",
    "answerSimilarity",
    "agentGoalAccuracyWithReference",
}
RAGAS_METRICS_WITH_CONTEXT = {
    "faithfulness",
    "answerRelevancy",
    "contextPrecision",
    "contextRecall",
}
RAGAS_METRICS_WITH_MULTIMODAL_CONTEXT = {
    "multimodalFaithfulness",
    "multimodalRelevancy",
}
RAGAS_METRICS_WITH_OUTPUT = {
    "faithfulness",
    "answerRelevancy",
    "answerSimilarity",
    "answerCorrectness",
    "multimodalFaithfulness",
    "multimodalRelevancy",
    "agentGoalAccuracyWithReference",
    "agentGoalAccuracyWithoutReference",
}
RAGAS_METRICS_WITH_ACTUAL_TOOL_CALLS = {
    "agentGoalAccuracyWithReference",
    "agentGoalAccuracyWithoutReference",
}


logger = logging.getLogger(__name__)


class RagasMetricsComputer(object):
    """
    Compute RAGAS LLM-as-a-judge metrics (RAG quality metrics and Agent Goal Accuracy).

    A single `RagasCompatibleLLM` judge is built at construction and reused for every computation. Fresh ragas
    metric instances and a fresh embeddings model are created for each computation.
    """
    llm: RagasCompatibleLLM
    embeddings_model_id: str
    max_workers: int
    fail_on_row_level_errors: bool
    can_compute_multimodal_metrics: bool

    def __init__(self, completion_llm_id: str, completion_settings: CompletionSettings, embedding_llm_id: str, max_workers: int, fail_on_row_level_errors: bool,
                 can_compute_multimodal_metrics: bool = False):
        """
        :param completion_llm_id: id of the completion LLM used as judge by ragas
        :param completion_settings: completion settings, converted with `get_llm_args` for the judge LLM
        :param embedding_llm_id: id of the embedding LLM used by embedding-based metrics
        :param max_workers: max number of concurrent ragas workers
        :param fail_on_row_level_errors: forwarded to ragas `evaluate` as `raise_exceptions`
        :param can_compute_multimodal_metrics: when False, multimodal metrics are silently dropped from the
            metrics requested in `compute_llm_metrics`
        """
        self.llm = RagasCompatibleLLM(llm_id=completion_llm_id, **get_llm_args(completion_settings))
        self.embeddings_model_id = embedding_llm_id
        self.max_workers = max_workers
        self.fail_on_row_level_errors = fail_on_row_level_errors
        self.can_compute_multimodal_metrics = can_compute_multimodal_metrics
        logger.info(f"Ragas metrics will be computed with completion LLM {completion_llm_id} and embedding LLM {embedding_llm_id}, "
                    f"with a max of {max_workers} workers")

    def compute_llm_metrics(self, metric_inputs: GenAIMetricInput, metrics: List[str], trace: Optional[SpanBuilder] = None) -> Tuple[Dict[str, Optional[float]], pd.DataFrame]:
        """
        Compute the requested RAGAS LLM metrics on the rows of `metric_inputs`.

        Rows with a null value in a column required by one of the selected metrics are filtered out first.

        :param metric_inputs: the evaluation columns (input, output, ground truth, context)
        :param metrics: requested metric identifiers; identifiers that are not ragas LLM metrics are ignored
        :param trace: optional span under which the completion and embedding calls are traced
        :return: the global value of each metric (None when ragas returns NaN), and a per-row DataFrame indexed
            like the rows kept after filtering. Both are keyed by the values of RAGAS_LLM_OUTPUT_METRICS_MAP.
        :raises RagasException: for any error, including inconsistent inputs
        """
        if not self.can_compute_multimodal_metrics:
            metrics = [metric for metric in metrics if metric not in RAGAS_METRICS_WITH_MULTIMODAL_CONTEXT]
        try:
            ragas_metrics_keys_to_compute = set(metrics) & RAGAS_LLM_METRICS_GENERATOR_MAP.keys()
            with_ground_truth_metrics = ragas_metrics_keys_to_compute & RAGAS_METRICS_WITH_GROUND_TRUTH
            with_context_metrics = ragas_metrics_keys_to_compute & RAGAS_METRICS_WITH_CONTEXT
            with_multimodal_context_metrics = ragas_metrics_keys_to_compute & RAGAS_METRICS_WITH_MULTIMODAL_CONTEXT
            with_output_metrics = ragas_metrics_keys_to_compute & RAGAS_METRICS_WITH_OUTPUT

            if with_context_metrics and with_multimodal_context_metrics:
                raise RagasException("Can't compute multimodal and textual context metrics (faithfulness, relevancy, precision, recall) simultaneously. Select only the metrics suitable to your context type.")

            # Empty values will crash ragas. Avoid them.
            # In theory we could still compute some metrics on rows with some empty values (e.g. answer relevancy
            # does not need the context). We don't bother, and rule out the entire row instead.
            ragas_metric_to_compute = get_ragas_llm_metrics(ragas_metrics_keys_to_compute)
            metric_names = [metric.name for metric in ragas_metric_to_compute]
            metric_names_by_column_role = {GenAIMetricInputRole.INPUT: metric_names}
            if with_output_metrics:
                metric_names_by_column_role[GenAIMetricInputRole.OUTPUT] = [metric.name for metric in get_ragas_llm_metrics(with_output_metrics)]
            if with_ground_truth_metrics:
                metric_names_by_column_role[GenAIMetricInputRole.GROUND_TRUTH] = [metric.name for metric in get_ragas_llm_metrics(with_ground_truth_metrics)]
            if with_context_metrics:
                # NOTE(review): multimodal context metrics do not register the CONTEXT role, so rows with a null
                # multimodal context are kept (read_multimodal_context_from_prompt_recipe then yields []) - confirm intended.
                metric_names_by_column_role[GenAIMetricInputRole.CONTEXT] = [metric.name for metric in get_ragas_llm_metrics(with_context_metrics)]

            filtered_columns = failure_utils.filter_null_rows_by_metric_input_role(
                metric_inputs, metric_names_by_column_role, metric_names
            )
            initial_index = filtered_columns.input.index  # Keep index for output

            ragas_input_df = pd.DataFrame()
            ragas_input_df['user_input'] = filtered_columns.input

            if with_output_metrics:
                ragas_input_df['response'] = filtered_columns.output
            if with_ground_truth_metrics:
                self._check_single_ground_truth(filtered_columns.ground_truth)
                ragas_input_df['reference'] = filtered_columns.ground_truth
            if with_context_metrics:
                # Only the first row is inspected to decide the column type.
                # Use positional access: after null-row filtering, label 0 may not exist in the index.
                first_context = filtered_columns.context.iloc[0]
                if isinstance(first_context, list):
                    ragas_input_df['retrieved_contexts'] = filtered_columns.context
                elif isinstance(first_context, str):
                    # ragas expects a list of contexts per row
                    ragas_input_df['retrieved_contexts'] = filtered_columns.context.apply(lambda x: [x])
                else:
                    raise RagasException(f"The context column '{metric_inputs.context.name}' must be of type string (or array of strings), got: {first_context}")
            elif with_multimodal_context_metrics:
                # No checks needed as it should be on the output of prompt recipe
                ragas_input_df['retrieved_contexts'] = filtered_columns.context.apply(read_multimodal_context_from_prompt_recipe)

            logger.info(f"The following RAGAS metric will be computed: {metric_names}")

            ragas_dataset = Dataset.from_pandas(ragas_input_df)
            # ragas column name -> user column name, only used to produce readable error messages
            ragas_column_mapping = {
                "user_input": metric_inputs.input.name if metric_inputs.input is not None else None,
                "response": metric_inputs.output.name if metric_inputs.output is not None else None,
                "reference": metric_inputs.ground_truth.name if metric_inputs.ground_truth is not None else None,
                "retrieved_contexts": metric_inputs.context.name if metric_inputs.context is not None else None
            }
            validate_column_dtypes(ragas_dataset, ragas_column_mapping, bool(with_multimodal_context_metrics))

            ret = self._run_ragas_evaluation(ragas_dataset, ragas_metric_to_compute, trace)

            logger.info(f"Global Ragas metrics result: {str(ret)}")
            ret_dict = ret._repr_dict
            ret_metric_names = list(ret_dict.keys())
            global_results = {
                RAGAS_LLM_OUTPUT_METRICS_MAP[metric]: ret_dict[metric] if not np.isnan(ret_dict[metric]) else None
                for metric in ret_metric_names
            }
            row_results = ret.to_pandas().set_index(initial_index)[ret_metric_names].rename(columns=RAGAS_LLM_OUTPUT_METRICS_MAP)
            return global_results, row_results

        except Exception as e:
            raise RagasException(f"An error happened during the computation of RAGAS metrics: {str(e)}", e) from e

    # Only computes Agent Goal Accuracy, so, not much effort in mutualising stuff with compute_llm_metrics
    def compute_agent_metrics(self, metric_inputs: GenAIMetricInput, metrics: List[str], has_ground_truth: bool, trace: Optional[SpanBuilder] = None) -> Tuple[Dict[str, Optional[float]], pd.DataFrame]:
        """
        Compute Agent Goal Accuracy, with reference if 'agentGoalAccuracyWithReference' is in `metrics`, without otherwise.

        Each row is turned into a ragas multi-turn sample with these messages:
        - the user input;
        - one AI tool-call message per actual tool call, each followed by a tool message when the tool output is known;
        - the agent output.

        :param metric_inputs: the evaluation columns (input, output, actual tool calls, ground truth)
        :param metrics: requested metric identifiers
        :param has_ground_truth: whether a ground truth column is available
        :param trace: optional span under which the completion and embedding calls are traced
        :return: the global metric value (None when ragas returns NaN), and a per-row DataFrame indexed like the rows
            kept after filtering, both keyed by the selected metric identifier
        :raises RagasException: for any error, including inconsistent inputs
        """
        try:
            if "agentGoalAccuracyWithReference" in metrics and "agentGoalAccuracyWithoutReference" in metrics:
                raise RagasException("Ragas metrics do not support computing both Agent Goal Accuracy With Reference and Agent Goal Accuracy Without Reference at once. Only one can be selected.")
            with_reference = 'agentGoalAccuracyWithReference' in metrics
            if with_reference and not has_ground_truth:
                raise RagasException("Agent Goal Accuracy With Reference requires a ground truth to use as a reference.")

            if with_reference:
                metric_to_compute = AgentGoalAccuracyWithReference()
                output_metric_name = 'agentGoalAccuracyWithReference'
            else:
                metric_to_compute = AgentGoalAccuracyWithoutReference()
                output_metric_name = 'agentGoalAccuracyWithoutReference'

            # Empty values will crash ragas: rule out rows with a null input, output or actual tool calls
            # (or ground truth, when it is used as reference).
            metric_names = [metric_to_compute.name]  # only one metric possible
            metric_names_by_column_role = {
                GenAIMetricInputRole.INPUT: metric_names,
                GenAIMetricInputRole.OUTPUT: metric_names,
                GenAIMetricInputRole.ACTUAL_TOOL_CALLS: metric_names
            }
            if with_reference:
                metric_names_by_column_role[GenAIMetricInputRole.GROUND_TRUTH] = metric_names

            filtered_columns = failure_utils.filter_null_rows_by_metric_input_role(
                metric_inputs, metric_names_by_column_role, metric_names
            )
            initial_index = filtered_columns.input.index  # Keep index for output

            if with_reference:
                self._check_single_ground_truth(filtered_columns.ground_truth)

            multiturn_samples = []
            have_name_only_tools = False
            for i, tool_calls in filtered_columns.actual_tool_calls.items():
                agent_input = filtered_columns.input[i]
                agent_output = filtered_columns.output[i]
                multiturn_sample_messages = [HumanMessage(content=agent_input)]
                for tool_call in tool_calls:
                    if isinstance(tool_call, dict):  # probably from prompt recipe
                        tool_input = tool_call.get('inputContent') or {}  # ragas requires a dict
                        tool_output = tool_call.get('outputContent')
                        tool_name = tool_call.get('toolName')
                        multiturn_sample_messages.append(AIMessage(content='I need to call these tools', tool_calls=[
                            ToolCall(name=tool_name, args=tool_input)]))
                        multiturn_sample_messages.append(ToolMessage(content=str(tool_output)))  # ToolMessage must be after an AIMessage
                    else:  # user supplied list of tool names
                        have_name_only_tools = True
                        multiturn_sample_messages.append(AIMessage(content='I need to call these tools', tool_calls=[
                            ToolCall(name=str(tool_call), args={})]))
                multiturn_sample_messages.append(AIMessage(content=agent_output))
                if with_reference:
                    multiturn_samples.append(MultiTurnSample(user_input=multiturn_sample_messages, reference=filtered_columns.ground_truth[i]))
                else:
                    multiturn_samples.append(MultiTurnSample(user_input=multiturn_sample_messages))

            if have_name_only_tools:
                logger.warning("Using Actual tool calls without input/output. Agent Goal Accuracy results might be worse than expected.")

            tool_dataset = EvaluationDataset(samples=multiturn_samples)
            ret = self._run_ragas_evaluation(tool_dataset, [metric_to_compute], trace)

            logger.info(f"Global Ragas metrics result: {str(ret)}")
            # ragas reports results under the metric's own name
            ragas_metric_name = metric_to_compute.name
            global_value = ret._repr_dict[ragas_metric_name]
            return ({output_metric_name: global_value if not np.isnan(global_value) else None},
                    ret.to_pandas().set_index(initial_index)[[ragas_metric_name]].rename(columns={ragas_metric_name: output_metric_name}))

        except Exception as e:
            raise RagasException(f"An error happened during the computation of RAGAS metrics: {str(e)}", e) from e

    def _run_ragas_evaluation(self, dataset: Union[Dataset, EvaluationDataset], ragas_metrics: List[Metric], trace: Optional[SpanBuilder]):
        """
        Run ragas `evaluate` with the judge LLM, a fresh embeddings model and the tracing callbacks.

        The errors collected by the judge LLM are always reset afterwards, including when the evaluation fails,
        so they cannot leak into the next computation.

        :raises RagasException: if the judge LLM rejected the temperature setting
        """
        # Langchain does not support callbacks on embedding, so we rely on our own mechanism by passing the trace to the `TraceableDKUEmbeddings` ctor.
        embeddings = DKUEmbeddings(llm_id=self.embeddings_model_id) if trace is None else TraceableDKUEmbeddings(trace, llm_id=self.embeddings_model_id)
        # For traces on completions, however, we can rely on Langchain callbacks
        callbacks: List[BaseCallbackHandler] = [] if trace is None else [CompletionTraceHandler(trace)]
        try:
            ret = evaluate(
                dataset,
                metrics=ragas_metrics,
                llm=self.llm,
                embeddings=embeddings,
                raise_exceptions=self.fail_on_row_level_errors,
                run_config=RunConfig(max_workers=self.max_workers, max_retries=3, timeout=600, max_wait=1800),
                callbacks=callbacks,
            )
            # Try to warn the user, but best effort, openAI error messages are not very stable
            if any('"Unsupported value: \'temperature\'' in e for e in self.llm.errors):
                raise RagasException('temperature is not supported with this model. Clear the temperature from the Advanced tab')
        finally:
            self.llm.errors = []
        return ret

    @staticmethod
    def _check_single_ground_truth(ground_truth: pd.Series) -> None:
        """
        Raise if the ground truth holds multiple values per row, which ragas metrics do not support.

        Only the first row is inspected.
        """
        ground_truth_type = type(ground_truth.iloc[0])
        if ground_truth_type in (list, np.ndarray):
            raise RagasException(f"Ragas metrics do not support multiple ground truths: the ground truth column needs to be of type string and not {str(ground_truth_type)}")


def check_use_ragas_metrics(metrics: List[str], has_ground_truth: bool, has_context: bool, has_actual_tool_calls: bool, completion_llm_id: str, embedding_llm_id: str,
                            can_compute_multimodal_metrics: bool) -> None:
    """
    Some ragas metrics require the ground truth, context etc.
    This method asserts that the metrics chosen by the user are consistent with their dataset.
    This check is on all metrics, agent and llm. We could in theory separate them, but we might as well fail early.

    :param metrics: the selected metric identifiers (non-ragas identifiers are ignored)
    :param has_ground_truth: whether a ground truth column is available
    :param has_context: whether a context column is available
    :param has_actual_tool_calls: whether an actual tool calls column is available
    :param completion_llm_id: id of the judge completion LLM (falsy if not configured)
    :param embedding_llm_id: id of the embedding LLM (falsy if not configured)
    :param can_compute_multimodal_metrics: whether multimodal metrics are checked at all
    :raises RagasException: if a selected ragas metric cannot be computed with this configuration
    """
    selected_metrics = set(metrics)

    # Sorted so that error messages are deterministic
    ground_truth_based_selected_metrics = sorted(selected_metrics & RAGAS_METRICS_WITH_GROUND_TRUTH)
    if ground_truth_based_selected_metrics and not has_ground_truth:
        raise RagasException(f"The following metrics require a ground truth column: {str(ground_truth_based_selected_metrics)}")

    context_based_selected_metrics = sorted(selected_metrics & RAGAS_METRICS_WITH_CONTEXT)
    if context_based_selected_metrics and not has_context:
        raise RagasException(f"The following metrics require a context column: {str(context_based_selected_metrics)}")

    tool_call_based_selected_metrics = sorted(selected_metrics & RAGAS_METRICS_WITH_ACTUAL_TOOL_CALLS)
    if tool_call_based_selected_metrics and not has_actual_tool_calls:
        raise RagasException(f"The following metrics require an actual tool calls column: {str(tool_call_based_selected_metrics)}")

    if can_compute_multimodal_metrics:
        multimodal_context_based_selected_metrics = sorted(selected_metrics & RAGAS_METRICS_WITH_MULTIMODAL_CONTEXT)
        if multimodal_context_based_selected_metrics and not has_context:
            raise RagasException(f"The following metrics require a multimodal context column: {str(multimodal_context_based_selected_metrics)}")

    # Both LLMs are only required when at least one ragas metric is selected
    ragas_selected_metrics = sorted(selected_metrics & RAGAS_AGENT_METRICS.union(RAGAS_LLM_METRICS_GENERATOR_MAP.keys()))
    if ragas_selected_metrics and (not completion_llm_id or not embedding_llm_id):
        raise RagasException(f"You need to select both an Embedding LLM and a Completion LLM to compute: {ragas_selected_metrics}. Please verify your recipe configuration.")


def has_context_based_metrics(metrics: List[str]) -> bool:
    """Tell whether at least one of `metrics` needs a (textual) context column."""
    return not RAGAS_METRICS_WITH_CONTEXT.isdisjoint(metrics)


def has_ragas_llm_metrics(metrics: List[str]) -> bool:
    """
    Tell whether `metrics` contains at least one ragas LLM metric identifier.
    """
    return not RAGAS_LLM_METRICS_GENERATOR_MAP.keys().isdisjoint(metrics)


def has_ragas_agent_metrics(metrics: List[str]) -> bool:
    """Tell whether `metrics` contains at least one ragas agent metric identifier."""
    return not RAGAS_AGENT_METRICS.isdisjoint(metrics)


def get_ragas_llm_metrics(metrics: Iterable[str]) -> List[Metric]:
    """
    Instantiate a fresh ragas metric for every ragas LLM metric identifier in `metrics`.

    Identifiers that are not ragas LLM metrics are skipped; the order of `metrics` is preserved.
    """
    return [
        RAGAS_LLM_METRICS_GENERATOR_MAP[metric_id]()
        for metric_id in metrics
        if metric_id in RAGAS_LLM_METRICS_GENERATOR_MAP
    ]


def create_empty_ragas_llm_metrics(metric_inputs: GenAIMetricInput, metrics: Iterable[str]) -> Tuple[Dict[str, None], pd.DataFrame]:
    """
    Build empty results for the requested ragas LLM metrics via `failure_utils.create_empty_metrics`.

    NOTE(review): the names passed on are ragas metric names (`Metric.name`), whereas
    RagasMetricsComputer.compute_llm_metrics keys its results with the values of RAGAS_LLM_OUTPUT_METRICS_MAP.
    Confirm that the caller or failure_utils reconciles the two.
    """
    ragas_metrics = get_ragas_llm_metrics(metrics)
    return failure_utils.create_empty_metrics(metric_inputs, [ragas_metric.name for ragas_metric in ragas_metrics])


def create_empty_ragas_agent_metrics(metric_inputs: GenAIMetricInput, metrics: Iterable[str]) -> Tuple[Dict[str, None], pd.DataFrame]:
    """Build empty results for the selected ragas agent metrics via `failure_utils.create_empty_metrics`."""
    selected_agent_metrics = RAGAS_AGENT_METRICS.intersection(metrics)
    return failure_utils.create_empty_metrics(metric_inputs, selected_agent_metrics)


def validate_column_dtypes(ds: Dataset, column_mapping: dict, expects_multimodal_context: bool):
    """
    Check that the columns of the ragas input dataset have the expected types.

    - `user_input`, `response` and `reference`, when present, must be string columns.
    - `retrieved_contexts`, when present, must either have at least one non-empty row (multimodal context)
      or be a sequence-of-strings feature (textual context).

    :param ds: the dataset that will be handed to ragas
    :param column_mapping: ragas column name -> user column name, only used in error messages
    :param expects_multimodal_context: whether `retrieved_contexts` holds multimodal (image) contexts
    :raises ValueError: if a column does not have the expected type
    """
    for ragas_column in ("user_input", "response", "reference"):
        if ragas_column not in ds.features:
            continue
        actual_dtype = ds.features[ragas_column].dtype
        if actual_dtype != "string":
            raise ValueError(
                f'Dataset feature "{column_mapping[ragas_column]}" should be of type string, got {actual_dtype}'
            )

    if "retrieved_contexts" not in ds.features:
        return

    if expects_multimodal_context:
        # Every row empty means no image could be read at all
        if not any(ds["retrieved_contexts"]):
            raise ValueError(
                f"Can't get multimodal context: unable to read images from Dataset feature \"{column_mapping['retrieved_contexts']}\". It should contain paths to images from a managed folder.")
        return

    contexts_feature = ds.features["retrieved_contexts"]
    is_sequence_of_strings = (
        isinstance(contexts_feature, Sequence)
        and hasattr(contexts_feature, "feature")
        and contexts_feature.feature.dtype == "string"
    )
    if not is_sequence_of_strings:
        raise ValueError(
            f"Can't get textual context from Dataset feature \"{column_mapping['retrieved_contexts']}\", it should be of type string (or array of strings).")


def read_multimodal_context_from_prompt_recipe(sources: List[Union[Dict, str]]) -> List[str]:
    """
    Turn the sources of one prompt recipe output row into ragas multimodal contexts.

    Plain string sources are kept as-is. Images referenced by dict sources are downloaded from their managed
    folder and base64-encoded. Two dict layouts are read:
    - updated format: `imageRefs` entries with `folderId`/`path`, at the top level of the source;
    - legacy format: `images` entries with `fullFolderId`/`path`, under the source's `excerpt`.

    Images that fail to download with an HTTPError are skipped with a warning. A dict source without an
    image list is skipped.

    :param sources: the sources of one row
    :return: the non-empty contexts (text or base64-encoded images), or [] if the sources cannot be parsed at all
    """
    contexts = []
    # Avoids downloading the same image twice within the row. The key includes the folder id because
    # different folders may contain images with the same path.
    image_cache = {}
    try:
        for s in sources:
            # For text-extracted documents the legacy and updated format both contain the string after parsing
            if isinstance(s, str):
                contexts.append(s)
                continue
            if s.get("excerpt") is None:
                # Handle updated format for image-extracted documents
                source_dict = s
                image_key = "imageRefs"
                folder_key = "folderId"
            else:
                # Handle legacy format for image-extracted documents
                source_dict = s.get("excerpt")
                image_key = "images"
                folder_key = "fullFolderId"
            images = source_dict.get(image_key)
            if images is None:
                # Nothing to read from this source; don't discard the rest of the row
                continue
            for image in images:
                full_folder_id = image[folder_key]
                project_key, lookup = full_folder_id.split(".", 1)
                image_path = image["path"]
                cache_key = (full_folder_id, image_path)
                if cache_key not in image_cache:
                    try:
                        folder = Folder(lookup, project_key, ignore_flow=True)
                        with folder.get_download_stream(image_path) as image_file:
                            image_cache[cache_key] = base64.b64encode(image_file.read()).decode("utf8")
                    except HTTPError as e:
                        logger.warning(f"Error when retrieving file {image_path} in folder {image[folder_key]}: {str(e)}")
                        continue
                contexts.append(image_cache[cache_key])
        return [c for c in contexts if c]
    except Exception as e:
        logger.warning(f"Error when trying to parse multimodal sources on the row {sources}: {str(e)}")
        return []
