import logging
import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import pandas as pd

import dataiku
from dataiku.base.utils import safe_unicode_str
from dataiku.doctor.prediction.custom_scoring import get_custom_score_or_None_and_error
from dataiku.langchain.dku_embeddings import DKUEmbeddings
from dataiku.langchain.dku_llm import DKULLM
from dataiku.llm.evaluation.custom_evaluation_trait import CustomEvaluationTrait
from dataiku.llm.evaluation.genai_custom_evaluation_metric  import GenAiCustomEvaluationMetric
from dataiku.doctor.diagnostics import diagnostics
from dataiku.llm.evaluation.exceptions import CustomMetricException
from dataiku.llm.evaluation.genai_metrics_input import GenAIMetricInput
from dataiku.llm.evaluation.genai_eval_recipe_desc import GenAIEvalRecipeDesc
from dataiku.llm.evaluation.utils.metrics_utils import get_llm_args, mean_or_none
from dataiku.llm.types import CompletionSettings

logger = logging.getLogger(__name__)


class RecipeParamsForCustomMetrics:
    """
    Parameters of the LLM Evaluation recipe that are accessible to custom metric functions.

    An instance is handed to the custom metric ``evaluate`` function as its ``recipe_params`` argument.
    Column names are copied verbatim from the recipe description. The completion and embedding LLM
    handles are instantiated eagerly; if instantiation fails, the attribute is set to ``None`` and a
    warning is logged instead of raising.
    """
    input_column_name: str
    output_column_name: str
    ground_truth_column_name: str
    context_column_name: str
    reference_tool_calls_column_name: str
    actual_tool_calls_column_name: str
    completion_llm: Optional[DKULLM]  # None when the completion LLM could not be instantiated
    embedding_llm: Optional[DKUEmbeddings]  # None when the embedding LLM could not be instantiated

    def __init__(self, recipe_desc: GenAIEvalRecipeDesc, completion_settings: CompletionSettings):
        """
        :param recipe_desc: recipe description providing the column names and the completion/embedding LLM ids.
        :param completion_settings: completion settings, converted by ``get_llm_args`` into keyword
            arguments for the completion ``DKULLM``.
        """
        self.input_column_name = recipe_desc.input_column_name
        self.output_column_name = recipe_desc.output_column_name
        self.ground_truth_column_name = recipe_desc.ground_truth_column_name
        self.context_column_name = recipe_desc.context_column_name
        self.reference_tool_calls_column_name = recipe_desc.reference_tool_calls_column_name
        self.actual_tool_calls_column_name = recipe_desc.actual_tool_calls_column_name
        try:
            self.completion_llm = DKULLM(llm_id=recipe_desc.completion_llm_id, **get_llm_args(completion_settings))
        except Exception as e:
            # Best effort: custom metrics that do not need a completion LLM must still be computable.
            self.completion_llm = None
            logger.warning("Cannot instantiate LLM: %s. Completion LLM won't be available. Reason %s",
                           recipe_desc.completion_llm_id, str(e))
        try:
            self.embedding_llm = DKUEmbeddings(llm_id=recipe_desc.embedding_llm_id)
        except Exception as e:
            # Best effort, as above. Report the *embedding* LLM id (the original logged the completion id).
            self.embedding_llm = None
            logger.warning("Cannot instantiate LLM: %s. Embedding LLM won't be available. Reason %s",
                           recipe_desc.embedding_llm_id, str(e))


def compute_custom_metric(custom_metric: GenAiCustomEvaluationMetric, input_df: pd.DataFrame, metric_inputs: GenAIMetricInput, fail_on_errors: bool, recipe_params_for_custom_metrics: RecipeParamsForCustomMetrics) -> Dict[str, Any]:
    """
    Parse the user code of a custom metric, run its ``evaluate`` function and return the metric result.

    Parsing and execution failures are routed through ``_fail``. That helper raises when
    ``fail_on_errors`` is set; otherwise it returns a result dict with ``didSucceed`` set to False.

    :param custom_metric: the custom metric; ``metric_code`` holds the user code and ``name`` its display name.
    :param input_df: the evaluated dataframe, passed to ``evaluate`` as ``input_df``.
    :param metric_inputs: interpreted columns, passed to ``evaluate`` as ``interpreted_columns``.
    :param fail_on_errors: raise instead of returning a failure result.
    :param recipe_params_for_custom_metrics: passed to ``evaluate`` as ``recipe_params``.
    :return: the result dict from ``_execute_parsed_custom_metric_function`` with the raw metric added
        under ``'metric'``, or the failure dict from ``_fail``.
    """
    parse_failure_template = "Parsing of custom metric evaluation function '%s' failed: %s. No score written for that metric. See logs for details."
    try:
        evaluate_fn = _get_custom_evaluate_func(custom_metric.metric_code)
    except SyntaxError as syntax_err:
        # Report the parser position rather than the raw exception text.
        location = "%s on line %s" % (syntax_err.msg, syntax_err.lineno)
        return _fail(parse_failure_template % (custom_metric.name, location), location, custom_metric, fail_on_errors)
    except Exception as parse_err:
        return _fail(parse_failure_template % (custom_metric.name, safe_unicode_str(parse_err)), parse_err, custom_metric, fail_on_errors)

    try:
        outcome = _execute_parsed_custom_metric_function(custom_metric.name, evaluate_fn, input_df, metric_inputs, recipe_params_for_custom_metrics)
        outcome['metric'] = custom_metric.get_raw()
    except Exception as exec_err:
        message = ("Execution of custom metric evaluation function '%s' failed: %s. "
                   "No score written for that metric. See logs for details.") % (custom_metric.name, safe_unicode_str(exec_err))
        return _fail(message, exec_err, custom_metric, fail_on_errors)
    return outcome


def compute_custom_trait(custom_trait: CustomEvaluationTrait, input_df: pd.DataFrame, metric_inputs: GenAIMetricInput, fail_on_errors: bool, recipe_params_for_custom_metrics: RecipeParamsForCustomMetrics) -> Dict[str, Any]:
    """
    Evaluate a custom trait (user-written assertions) row by row with the completion LLM.

    For every row of ``metric_inputs.input``, the completion LLM receives three messages:
    - a fixed system prompt;
    - the row's properties: the input, plus output / ground truth / actual and reference tool calls
      when those columns are present and have any truthy value;
    - the trait's assertions.
    The LLM is asked to answer "true" or "false". Lines of the trait prompt whose first non-blank
    character is '#' are dropped. Queries are sent in batches of 10.

    :param custom_trait: the trait; ``prompt`` holds the assertions and ``name`` its display name.
    :param input_df: not used by this function (kept for signature parity with ``compute_custom_metric``).
    :param metric_inputs: interpreted columns; the length of ``input`` determines the number of rows evaluated.
    :param fail_on_errors: if True, raise ``CustomMetricException`` when no completion LLM is configured,
        or after all batches when any row failed.
    :param recipe_params_for_custom_metrics: recipe params; its ``completion_llm`` is required.
    :return: dict with ``'metric'``, ``'value'`` (mean of the per-row "true" indicator),
        ``'rowByRowStringValues'`` ("true" / "false" / None per row), ``'didSucceed'``, and ``'error'``
        when some rows failed. When no completion LLM is configured and ``fail_on_errors`` is False,
        returns the failure dict produced by ``_fail``.
    :raises CustomMetricException: see ``fail_on_errors``.
    """
    completion_llm = recipe_params_for_custom_metrics.completion_llm
    if not completion_llm:
        error = "You need to select a Completion LLM to compute Custom Traits."
        # Return the failure result (or raise, via _fail). Falling through would crash on completion_llm.llm_id.
        return _fail(error, error, custom_trait, fail_on_errors)
    project = dataiku.api_client().get_default_project()
    chat_llm = project.get_llm(completion_llm.llm_id)

    system_prompt = """You are an expert agent answer evaluator. Your task is to check that the given agent answer follows all of the user given assertions.
The assertions are given by the user, and each assertion check properties of the agent answer. The answer's properties are listed last.
Use the answer's properties to validate the assertions.
You must answer with only "true" or "false", with no additional text or explanation:
* **true**: If **all** of the user assertions are true.
* **false**: If any of the user assertions is false, or if you are not sure of the truth of any of the assertions."""
    # '#'-prefixed lines are comments in the trait prompt and are not sent to the LLM.
    assertion_lines = [line for line in custom_trait.prompt.splitlines() if not line.lstrip().startswith('#')]
    trait_prompt = "User assertions:\n" + '\n'.join(assertion_lines)

    # Which optional columns appear in the prompt does not depend on the batch, so decide once.
    optional_columns = _present_optional_trait_columns(metric_inputs)

    row_by_row_values = []
    had_errors = False
    error_message = ""
    batch_size = 10
    row_count = len(metric_inputs.input)
    for batch_start in range(0, row_count, batch_size):
        batch_end = min(batch_start + batch_size, row_count)
        completions = chat_llm.new_completions()
        _apply_trait_completion_settings(completions, completion_llm)

        for i in range(batch_start, batch_end):
            answer_properties = f"Input:\n{metric_inputs.input[i]}\n"
            answer_properties += "".join(f"{label}: {values[i]}\n" for label, values in optional_columns)
            completion = completions.new_completion()
            completion.with_message(system_prompt, role="system")
            completion.with_message(answer_properties, role="user")
            completion.with_message(trait_prompt, role="user")

        logger.info(f"Batching computation of Trait '{custom_trait.name}' for rows {batch_start} to {batch_end - 1}")

        responses = completions.execute()

        for offset, response in enumerate(responses.responses):
            # Assumes responses come back in query order, as does the row-by-row list built here.
            # The dataset row is therefore batch_start + offset, not the in-batch offset.
            row_index = batch_start + offset
            query = completions.queries[offset]
            if response.success:
                logger.debug(f"Response to query {query} is {response.text}")  # response.text throws if the query had an error...
                # Be strict: only an exact "true" (ignoring case and surrounding whitespace) counts as a pass.
                row_by_row_values.append("true" if response.text.strip().lower() == "true" else "false")
                continue

            row_by_row_values.append(None)
            try:
                row_error_message = response.text
            except Exception as e:
                row_error_message = safe_unicode_str(e)
            # No exc_info here: this runs outside an except block, so there is no traceback to attach.
            logger.error("Trait '%s' failed on row %s: %s", custom_trait.name, row_index, row_error_message)
            if not had_errors:
                had_errors = True
                error_message = f"On row {row_index}: {row_error_message}"
                diagnostics.add_or_update(
                    diagnostics.DiagnosticType.LLM_EVALUATION_COMPUTATION_ERROR,
                    f"Custom prompt evaluation function '{custom_trait.name}' did fail for some rows. First failure message, was: {error_message}"
                )

    if had_errors and fail_on_errors:
        raise CustomMetricException(error_message)

    result = {
        'metric': custom_trait.get_raw(),
        # NOTE(review): failed rows (None) compare unequal to 'true' and therefore count as 0 in this mean.
        'value': mean_or_none((pd.Series(row_by_row_values) == 'true').astype(int)),
        'rowByRowStringValues': row_by_row_values,
    }

    if had_errors:
        result['didSucceed'] = False
        result['error'] = error_message
    else:
        result['didSucceed'] = True

    return result


def _present_optional_trait_columns(metric_inputs: GenAIMetricInput) -> List[tuple]:
    """Return ``(label, values)`` for each optional column that is not None and whose ``.any()`` is truthy, in prompt order."""
    candidates = (
        ("Output", metric_inputs.output),
        ("Ground truth", metric_inputs.ground_truth),
        ("Actual tool calls", metric_inputs.actual_tool_calls),
        ("Reference tool calls", metric_inputs.reference_tool_calls),
    )
    return [(label, values) for label, values in candidates if values is not None and values.any()]


def _apply_trait_completion_settings(completions, completion_llm: DKULLM) -> None:
    """Copy onto a completions batch the sampling settings of ``completion_llm`` that are not None."""
    for setting_name in ('temperature', 'top_k', 'top_p', 'max_tokens'):
        value = getattr(completion_llm, setting_name)
        if value is not None:
            completions.settings[setting_name] = value


def _fail(explicit_error, exception, custom_metric, fail_on_errors):
    """
    Report a custom metric / trait failure.

    When ``fail_on_errors`` is set, raise ``CustomMetricException(explicit_error)``. Otherwise:
    - log ``explicit_error``;
    - record it as an ``LLM_EVALUATION_COMPUTATION_ERROR`` diagnostic;
    - return an unsuccessful result dict.

    :param explicit_error: user-facing message, used for the exception, the log and the diagnostic.
    :param exception: the underlying error (exception or message); its text is stored under ``'error'``.
    :param custom_metric: object whose ``get_raw()`` value is stored under ``'metric'``.
    :param fail_on_errors: raise instead of returning.
    :return: ``{'metric': ..., 'didSucceed': False, 'error': ...}``
    """
    if not fail_on_errors:
        logger.error(explicit_error, exc_info=True)
        diagnostics.add_or_update(diagnostics.DiagnosticType.LLM_EVALUATION_COMPUTATION_ERROR, explicit_error)
        return {
            'metric': custom_metric.get_raw(),
            'didSucceed': False,
            'error': safe_unicode_str(exception),
        }
    raise CustomMetricException(explicit_error)


def _get_custom_evaluate_func(code: str) -> Callable[[pd.DataFrame], Union[float, Optional[List[float]]]]:
    exec_dic = {}
    exec(code, exec_dic, exec_dic)  # python2_friendly_exec
    if "evaluate" not in exec_dic:
        raise ValueError("Custom metric evaluation function not defined")
    evaluate_function = exec_dic["evaluate"]
    sig = inspect.signature(evaluate_function).parameters
    args = list(sig.keys())
    if len(args) < 3 or args[0] != 'input_df' or args[1] != 'recipe_params' or args[2] != 'interpreted_columns':
        raise ValueError("Custom metric evaluation function expects at least three arguments, (input_df, recipe_params, interpreted_columns). Got: (%s)" % ", ".join(args))

    return evaluate_function


def _execute_parsed_custom_metric_function(
        custom_metric_name: str,
        evaluate_function,
        input_df: pd.DataFrame,
        metric_inputs: GenAIMetricInput,
        recipe_params_for_custom_metrics: RecipeParamsForCustomMetrics) -> Dict[str, Any]:
    """
    Run a parsed custom metric ``evaluate`` function and normalize its result.

    ``evaluate`` may return either a single global value, or an indexable whose item 0 is the global
    value and item 1 the row-by-row values (one per row of ``input_df``). Each value is passed through
    ``get_custom_score_or_None_and_error``; invalid values are replaced by its corrected value and a
    warning is logged. Missing or unusable row-by-row values, including a wrong count, yield
    ``'rowByRowValues': None``.

    :param custom_metric_name: display name used in messages.
    :param evaluate_function: the user ``evaluate(input_df, recipe_params, interpreted_columns)`` function.
    :param input_df: evaluated dataframe; its row count is the expected number of row-by-row values.
    :param metric_inputs: interpreted columns, passed as ``interpreted_columns``.
    :param recipe_params_for_custom_metrics: passed as ``recipe_params``.
    :return: dict with ``'value'``, ``'rowByRowValues'`` and ``'didSucceed'`` (always True here).
    :raises ValueError: if ``evaluate_function`` raises; the original error is chained as the cause.
    """
    try:
        result = evaluate_function(input_df, recipe_params_for_custom_metrics, metric_inputs)
    except Exception as e:
        raise ValueError("Custom metric evaluation function '%s' failed: %s." % (custom_metric_name, safe_unicode_str(e))) from e

    custom_metric_result = {}
    try:
        score = result[0]
    except (TypeError, IndexError):  # result is not indexable: assume it is the global value itself
        score = result
    corrected_score, error = get_custom_score_or_None_and_error(score, allow_naninf=False)
    if error:
        logger.warning("Custom metric evaluation function '%s', global metric value %s is invalid: %s",
                       custom_metric_name, score, str(error))
    custom_metric_result['value'] = corrected_score

    try:
        row_by_row = list(result[1])
        expected_rows = input_df.shape[0]
        if len(row_by_row) != expected_rows:
            # Explicit check rather than `assert`, which is stripped under `python -O` and would let a
            # wrongly-sized row-by-row list through.
            raise ValueError("expected %s row-by-row values, got %s" % (expected_rows, len(row_by_row)))
        error_in_rows = False
        for i, row_score in enumerate(row_by_row):
            corrected_row_score, row_error = get_custom_score_or_None_and_error(row_score, allow_naninf=False)
            if row_error:
                logger.warning("Custom metric evaluation function '%s', row %s: %s", custom_metric_name, i, str(row_error))
                error_in_rows = True
                row_by_row[i] = corrected_row_score
        if error_in_rows:
            diagnostics.add_or_update(
                diagnostics.DiagnosticType.LLM_EVALUATION_COMPUTATION_ERROR,
                "Custom metric evaluation function '%s' did produce some row(s) with non finite values. Ignoring them. See logs for details." % custom_metric_name
            )
        custom_metric_result['rowByRowValues'] = row_by_row
        logger.info("Custom metric evaluation function '%s' produced a global metric value and row-by-row values", custom_metric_name)
    except (TypeError, IndexError, AssertionError, ValueError):  # result is like (float, something_not_valid)
        custom_metric_result['rowByRowValues'] = None
        logger.warning("Custom metric evaluation function '%s' did not produce finite valid row-by-row values, assuming only a global metric value",
                       custom_metric_name)
        # No diagnostic here: returning only a global value is a normal/expected case.

    custom_metric_result['didSucceed'] = True
    return custom_metric_result
