import logging
import inspect
from typing import Optional

import pandas as pd

from dataiku.base.utils import safe_unicode_str
from dataiku.doctor.prediction.custom_scoring import get_custom_score_or_None_and_error
from dataiku.langchain.dku_embeddings import DKUEmbeddings
from dataiku.langchain.dku_llm import DKULLM
from dataiku.llm.evaluation.custom_llm_evaluation_metric import CustomLLMEvaluationMetric
from dataiku.doctor.diagnostics import diagnostics
from dataiku.llm.evaluation.exceptions import CustomMetricException
from dataiku.llm.evaluation.llm_metrics_input import LLMMetricInput
from dataiku.llm.evaluation.llm_eval_recipe_desc import LLMEvalRecipeDesc
from dataiku.llm.evaluation.utils.metrics_utils import get_llm_args
from dataiku.llm.types import CompletionSettings

# Module-level logger used by the functions below for warnings, errors and info messages.
logger = logging.getLogger(__name__)


class RecipeParamsForCustomMetrics:
    """
    Parameters of the LLM Evaluation recipe that are made accessible to custom metric functions.

    An instance is passed as the second positional argument (``recipe_params``) of a
    custom metric's ``evaluate(input_df, recipe_params, interpreted_columns)`` function.
    """
    # Column names copied from the recipe description.
    input_column_name: str
    output_column_name: str
    ground_truth_column_name: str
    context_column_name: str
    # LLM handles. When instantiation fails in __init__, no instance attribute is set,
    # so these class-level None defaults are what callers read.
    completion_llm: Optional[DKULLM] = None
    embedding_llm: Optional[DKUEmbeddings] = None

    def __init__(self, recipe_desc: LLMEvalRecipeDesc, completion_settings: CompletionSettings):
        """
        Copy the column names from ``recipe_desc`` and try to instantiate the LLMs.

        :param recipe_desc: recipe description providing column names and the
            completion / embedding LLM ids.
        :param completion_settings: settings converted by ``get_llm_args`` into extra
            keyword arguments for the completion ``DKULLM``.

        A failure to instantiate either LLM is logged as a warning instead of raised;
        the corresponding attribute then stays None.
        """
        self.input_column_name = recipe_desc.input_column_name
        self.output_column_name = recipe_desc.output_column_name
        self.ground_truth_column_name = recipe_desc.ground_truth_column_name
        self.context_column_name = recipe_desc.context_column_name
        try:
            self.completion_llm = DKULLM(llm_id=recipe_desc.completion_llm_id, **get_llm_args(completion_settings))
        except Exception as e:
            logger.warning("Cannot instantiate LLM: %s. Completion LLM won't be available. Reason %s",
                           recipe_desc.completion_llm_id, e)
        try:
            self.embedding_llm = DKUEmbeddings(llm_id=recipe_desc.embedding_llm_id)
        except Exception as e:
            # Report the id of the embedding LLM that actually failed to load.
            logger.warning("Cannot instantiate LLM: %s. Embedding LLM won't be available. Reason %s",
                           recipe_desc.embedding_llm_id, e)


def compute_custom_metric(custom_metric: CustomLLMEvaluationMetric, input_df: pd.DataFrame, interpreted_columns: LLMMetricInput, fail_on_errors: bool, recipe_params_for_custom_metrics: RecipeParamsForCustomMetrics) -> dict:
    """
    Parse, run and package the result of a single custom LLM evaluation metric.

    :param custom_metric: metric definition; its ``metric_code`` must define an ``evaluate`` function.
    :param input_df: dataframe handed to ``evaluate``.
    :param interpreted_columns: interpreted columns handed to ``evaluate``.
    :param fail_on_errors: when True, a parsing or execution failure raises
        CustomMetricException; otherwise the failure is logged, added to diagnostics
        and returned as a result with ``didSucceed`` set to False.
    :param recipe_params_for_custom_metrics: recipe parameters handed to ``evaluate``.
    :return: result dict that always contains ``metric`` (``custom_metric.get_raw()``)
        and ``didSucceed``.
    :raises CustomMetricException: on failure when ``fail_on_errors`` is True.
    """
    parsing_failed_template = "Parsing of custom metric evaluation function '%s' failed: %s. No score written for that metric. See logs for details."

    # Stage 1: turn the metric source code into a callable.
    try:
        evaluate_function = _get_custom_evaluate_func(custom_metric.metric_code)
    except SyntaxError as syntax_error:
        # Syntax errors get a compact "<message> on line <n>" description.
        location_message = "%s on line %s" % (syntax_error.msg, syntax_error.lineno)
        return _fail(parsing_failed_template % (custom_metric.name, location_message),
                     location_message, custom_metric, fail_on_errors)
    except Exception as parsing_error:
        return _fail(parsing_failed_template % (custom_metric.name, safe_unicode_str(parsing_error)),
                     parsing_error, custom_metric, fail_on_errors)

    # Stage 2: run it and attach the raw metric definition to the result.
    try:
        metric_result = _execute_parsed_custom_metric_function(
            custom_metric.name, evaluate_function, input_df, interpreted_columns, recipe_params_for_custom_metrics)
        metric_result['metric'] = custom_metric.get_raw()
    except Exception as execution_error:
        execution_message = "Execution of custom metric evaluation function '%s' failed: %s. No score written for that metric. See logs for details." % (custom_metric.name, safe_unicode_str(execution_error))
        return _fail(execution_message, execution_error, custom_metric, fail_on_errors)
    else:
        return metric_result


def _fail(explicit_error, exception, custom_metric, fail_on_errors):
    """
    Handle a custom metric failure.

    :param explicit_error: human-readable message describing the failure.
    :param exception: the underlying exception (or an error description string).
    :param custom_metric: metric whose raw definition is attached to the failed result.
    :param fail_on_errors: when True, raise instead of returning a failed result.
    :return: ``{'metric': ..., 'didSucceed': False, 'error': ...}`` when ``fail_on_errors`` is False.
    :raises CustomMetricException: when ``fail_on_errors`` is True.
    """
    if not fail_on_errors:
        # exc_info=True attaches the exception currently being handled by the caller.
        logger.error(explicit_error, exc_info=True)
        diagnostics.add_or_update(diagnostics.DiagnosticType.LLM_EVALUATION_COMPUTATION_ERROR, explicit_error)
        return dict(
            metric=custom_metric.get_raw(),
            didSucceed=False,
            error=safe_unicode_str(exception),
        )
    raise CustomMetricException(explicit_error)


def _get_custom_evaluate_func(code):
    """
    Execute custom metric source code and return the ``evaluate`` function it defines.

    :param code: Python source of the custom metric; it must define ``evaluate``.
    :return: the ``evaluate`` callable.
    :raises SyntaxError: if ``code`` cannot be parsed.
    :raises ValueError: if ``evaluate`` is not defined, or if its first three parameters
        are not ``input_df``, ``recipe_params``, ``interpreted_columns`` (in that order).
    """
    # NOTE(review): exec runs arbitrary code; presumably acceptable because metric code is
    # authored by recipe users - confirm this never receives untrusted input.
    namespace = {}
    # One dict is used as both globals and locals, so top-level imports and helpers in the
    # user code are visible from inside ``evaluate``.
    exec(code, namespace, namespace)
    if "evaluate" not in namespace:
        raise ValueError("Custom metric evaluation function not defined")
    evaluate_function = namespace["evaluate"]

    parameter_names = list(inspect.signature(evaluate_function).parameters)
    # Extra parameters after the three mandatory ones are allowed.
    if parameter_names[:3] != ["input_df", "recipe_params", "interpreted_columns"]:
        raise ValueError(
            "Custom metric evaluation function expects at least three arguments, (input_df, recipe_params, interpreted_columns). Got: (%s)"
            % ", ".join(parameter_names)
        )
    return evaluate_function


def _execute_parsed_custom_metric_function(
        custom_metric_name: str,
        evaluate_function,
        input_df: pd.DataFrame,
        interpreted_columns: LLMMetricInput,
        recipe_params_for_custom_metrics: RecipeParamsForCustomMetrics):
    """
    Run a parsed custom metric ``evaluate`` function and normalize its result.

    ``evaluate`` may return either a single global score, or an indexable pair
    ``(global_score, row_by_row_scores)`` whose second element has one entry per row
    of ``input_df``.

    :param custom_metric_name: metric name, used in log and diagnostic messages.
    :param evaluate_function: callable returned by ``_get_custom_evaluate_func``.
    :param input_df: dataframe passed to ``evaluate``.
    :param interpreted_columns: interpreted columns passed to ``evaluate``.
    :param recipe_params_for_custom_metrics: recipe parameters passed to ``evaluate``.
    :return: dict with ``value`` (corrected global score), ``rowByRowValues`` (list of
        corrected row scores, or None when none are usable) and ``didSucceed`` (True).
    :raises ValueError: if ``evaluate`` itself raises.
    """
    try:
        result = evaluate_function(input_df, recipe_params_for_custom_metrics, interpreted_columns)
    except Exception as e:
        raise ValueError("Custom metric evaluation function '%s' failed: %s." % (custom_metric_name, safe_unicode_str(e))) from e

    try:
        score = result[0]
    except (TypeError, IndexError):  # result is not indexable, we expect it is just a float
        score = result
    corrected_score, error = get_custom_score_or_None_and_error(score, allow_naninf=False)
    if error:
        logger.warning("Custom metric evaluation function '%s', global metric value %s is invalid: %s",
                       custom_metric_name, score, error)

    row_by_row = _extract_row_by_row_values(custom_metric_name, result, input_df.shape[0])
    if row_by_row is None:
        # No warning in diagnostic, this is a "normal/expected" issue.
        logger.warning("Custom metric evaluation function '%s' did not produce finite valid row-by-row values, assuming only a global metric value",
                       custom_metric_name)
    else:
        logger.info("Custom metric evaluation function '%s' produced a global metric value and row-by-row values",
                    custom_metric_name)

    return {
        'value': corrected_score,
        'rowByRowValues': row_by_row,
        'didSucceed': True,
    }


def _extract_row_by_row_values(custom_metric_name, result, expected_row_count):
    """
    Return the corrected row-by-row scores held in ``result[1]``, or None if unusable.

    None is returned when ``result`` has no second element, when that element is not
    iterable, when its length differs from ``expected_row_count``, or when scoring raises
    TypeError/ValueError. Every entry is replaced by the corrected value returned by
    ``get_custom_score_or_None_and_error``, as is done for the global value. Invalid
    entries are logged individually and reported once as a diagnostic.
    """
    try:
        raw_scores = list(result[1])
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(raw_scores) != expected_row_count:
            return None
        checked_scores = [get_custom_score_or_None_and_error(raw_score, allow_naninf=False)
                          for raw_score in raw_scores]
    except (TypeError, IndexError, ValueError):  # result is like (float, something_not_valid)
        return None

    row_by_row = []
    error_in_rows = False
    for i, (corrected_score, error) in enumerate(checked_scores):
        if error:
            logger.warning("Custom metric evaluation function '%s', row %s: %s", custom_metric_name, i, error)
            error_in_rows = True
        row_by_row.append(corrected_score)

    if error_in_rows:
        diagnostics.add_or_update(
            diagnostics.DiagnosticType.LLM_EVALUATION_COMPUTATION_ERROR,
            "Custom metric evaluation function '%s' did produce some row(s) with non finite values. Ignoring them. See logs for details." % custom_metric_name
        )
    return row_by_row
