import logging

import numpy as np
import pandas as pd
from typing import Collection, List, Tuple, Dict, Iterable

from dataiku.doctor.diagnostics import diagnostics
from dataiku.llm.evaluation.exceptions import LLMEvalException, MissingColumnException
from dataiku.llm.evaluation.genai_metrics_input import GenAIMetricInput, GenAIMetricInputRole

logger = logging.getLogger(__name__)


def filter_null_rows(
    metric_input: GenAIMetricInput, metric_names: List[str], input_roles: Collection[GenAIMetricInputRole]
) -> GenAIMetricInput:
    """
    Convenience front-end to filter_null_rows_by_metric_input_role for the common case where every metric needs the
    same input roles: each role of input_roles is associated with the whole metric_names list, and the filtering is
    delegated with that mapping.

    Do not use it when different metrics need different input roles (as in the ragas_utils files): build the
    per-role mapping yourself and call filter_null_rows_by_metric_input_role directly.
    """
    # Every role shares the very same list of metric names.
    shared_metric_names_by_role = dict.fromkeys(input_roles, metric_names)
    return filter_null_rows_by_metric_input_role(metric_input, shared_metric_names_by_role, metric_names)


def filter_null_rows_by_metric_input_role(
    columns: GenAIMetricInput,
    metric_names_by_input_role: Dict[GenAIMetricInputRole, List[str]],
    metric_names: List[str],
) -> GenAIMetricInput:
    """
    Drop the rows that have a null value in at least one of the required input roles.

    Dropped rows are reported through a logged warning and an LLM_EVALUATION_COMPUTATION_ERROR diagnostic, listing at
    most the first 20 faulty row labels.

    :param columns: the evaluation inputs to filter.
    :param metric_names_by_input_role: the required input roles (keys), each mapped to the names of the metrics
        that need it (values, used in the "missing column" error message).
    :param metric_names: names of the metrics being computed, used in the error and warning messages.
    :return: ``columns`` itself when no input role is required; otherwise a new GenAIMetricInput restricted to the
        rows that are non-null for every required role.
    :raises MissingColumnException: if a required column is missing or only contains nulls.
    :raises LLMEvalException: if every row has a null in at least one of the required columns.
    """
    # No required role at all: nothing to check, nothing to filter.
    if not metric_names_by_input_role:
        return columns

    rows_without_null = None
    column_names = []
    for input_role, role_metric_names in metric_names_by_input_role.items():
        column_values = columns.get(input_role)
        if column_values is None or column_values.isnull().all():
            raise MissingColumnException(
                f"Error computing {role_metric_names}: those metrics require column '{input_role.value}', which is missing or empty."
            )

        column_names.append(input_role.value)

        # A row is kept only if it holds a value in every required column.
        has_value = column_values.notnull()
        rows_without_null = has_value if rows_without_null is None else rows_without_null & has_value

    rows_with_null: pd.Series = ~rows_without_null  # type: ignore - rows_without_null can't be None here
    if rows_with_null.all():
        raise LLMEvalException(
            f"Error computing {metric_names} all rows have at least one empty value on the required columns {column_names}. Can't compute metrics."
        )

    # Test the boolean mask, not the index labels: Index.any() looks at the truthiness of the labels themselves, so
    # a single faulty row labelled 0 would otherwise be dropped without any warning.
    if rows_with_null.any():
        null_indexes = rows_with_null[rows_with_null].index
        max_print = 20
        faulty_rows = null_indexes[:max_print].to_list()
        error_message = (
            f"Warning computing {metric_names}: some rows are missing some values for one of the following columns: "
            f"{column_names}. Dismissing them from computation. Faulty rows: {faulty_rows}"
        )
        if len(null_indexes) > max_print:
            error_message += " (and %s other rows)" % (len(null_indexes) - max_print)
        logger.warning(error_message)
        diagnostics.add_or_update(diagnostics.DiagnosticType.LLM_EVALUATION_COMPUTATION_ERROR, error_message)

    # Apply the same row mask to every series that is present, including those that were not required.
    result = GenAIMetricInput.from_series(
        columns.input[rows_without_null] if columns.input is not None else None,
        columns.output[rows_without_null] if columns.output is not None else None,
        columns.ground_truth[rows_without_null] if columns.ground_truth is not None else None,
        columns.context[rows_without_null] if columns.context is not None else None,
        columns.actual_tool_calls[rows_without_null] if columns.actual_tool_calls is not None else None,
        columns.reference_tool_calls[rows_without_null] if columns.reference_tool_calls is not None else None,
    )
    return result


def raise_or_continue(e, metric_name, fail_on_errors, input_format):
    """
    Re-raise a metric computation error, or record it and let the caller carry on.

    When the error wraps a MissingColumnException and input_format is 'PROMPT_RECIPE' or 'DATAIKU_ANSWERS', it is
    first wrapped in a new LLMEvalException whose message is extended with a configuration hint for that format.

    :param e: the exception raised while computing the metric. Its optional ``original_exception`` attribute is
        inspected; an exception without that attribute is handled as a plain error.
    :param metric_name: name of the metric being computed, used in the logged / diagnostic message.
    :param fail_on_errors: when truthy the (possibly enriched) exception is raised; otherwise it is logged and added
        to the diagnostics as LLM_EVALUATION_COMPUTATION_ERROR, and the function returns normally.
    :param input_format: input format identifier, used to select the hint appended to missing-column errors.
    :raises Exception: ``e`` (or its enriched LLMEvalException wrapper) when fail_on_errors is truthy.
    """
    # Read the attribute defensively: this is an error handler, it must not hide the real error behind an
    # AttributeError when handed an exception that has no 'original_exception'.
    original_exception = getattr(e, "original_exception", None)
    if original_exception is not None and isinstance(original_exception, MissingColumnException):
        hint = None
        if input_format == 'PROMPT_RECIPE':
            hint = (
                " Make sure \"Raw query output mode\" and \"Raw response output mode\" in your Prompt recipe are not set to \"None\"."
                " If computing context-based metrics, make sure that your Prompt recipe uses a Retrieval-augmented LLM with "
                "\"Source output format\" set to \"Separated\"."
            )
        elif input_format == 'DATAIKU_ANSWERS':
            hint = " Make sure the \"Retrieval Method\" is set to 'Use knowledge bank retrieval' in Dataiku Answers' settings."

        if hint is not None:
            # Keep the initial error as the original exception of the enriched one.
            e = LLMEvalException(e.message + hint, e)

    if fail_on_errors:
        raise e

    explicit_error = "Error computing %s metric: %s." % (metric_name, str(e))
    logger.error(explicit_error + " Stop on errors is not enabled, carrying on with other metrics.")
    diagnostics.add_or_update(
        diagnostics.DiagnosticType.LLM_EVALUATION_COMPUTATION_ERROR,
        explicit_error
    )


def warn(explicit_warning, raise_diagnostic=True):
    """
    Log explicit_warning at warning level and, unless raise_diagnostic is falsy, also record it as an
    LLM_EVALUATION_COMPUTATION_ERROR diagnostic.
    """
    logger.warning(explicit_warning)
    if not raise_diagnostic:
        return

    diagnostic_type = diagnostics.DiagnosticType.LLM_EVALUATION_COMPUTATION_ERROR
    diagnostics.add_or_update(diagnostic_type, explicit_warning)


def create_empty_metrics(columns: GenAIMetricInput, metrics: Iterable[str]) -> Tuple[Dict[str, None], pd.DataFrame]:
    """
    Build placeholder results for metrics that could not be computed.

    :param columns: the evaluation inputs; the index of ``columns.input`` is used as the index of the row-by-row
        placeholder.
    :param metrics: names of the metrics to create placeholders for. Any iterable is accepted, including
        single-pass ones such as generators.
    :return: a tuple made of a dict mapping each metric name to None, and a DataFrame with one float64 column per
        metric, filled with missing values.
    """
    # Materialize once: 'metrics' is walked twice below, and a single-pass iterable would already be exhausted
    # when building the DataFrame, yielding a frame without any column.
    metric_names = list(metrics)
    empty_perf = {metric: None for metric in metric_names}
    empty_row_by_row = pd.DataFrame(
        {metric: pd.Series(None, columns.input.index, dtype=np.dtype(np.float64)) for metric in metric_names}
    )
    return empty_perf, empty_row_by_row
