from typing import List, Optional, Union

import pandas as pd
from dataiku.llm.evaluation.exceptions import MetricInputException
from dataiku.llm.evaluation.genai_eval_recipe_desc import GenAIEvalRecipeDesc
from dataikuapi.dss.utils import Enum


class GenAIMetricInputRole(Enum):
    """
    Roles a column can play as input to a GenAI evaluation metric.

    Each role maps to one attribute of ``GenAIMetricInput`` (see
    ``GenAIMetricInput.get``). The string values look like human-readable
    labels; NOTE(review): confirm whether they are also used in messages or
    persisted settings before changing them.
    """
    INPUT = "Input"
    OUTPUT = "Output"
    GROUND_TRUTH = "Ground truth"
    CONTEXT = "Context"
    ACTUAL_TOOL_CALLS = "Actual Tool Calls"
    REFERENCE_TOOL_CALLS = "Reference Tool Calls"


class GenAIMetricInput(object):
    """
    Holds the columns from the input dataframe as Panda Series
    """

    def __init__(
        self,
        input_series: Optional[pd.Series],
        output_series: Optional[pd.Series],
        ground_truth_series: Optional[pd.Series],
        context_series: Optional[pd.Series],
        actual_tool_calls_series: Optional[pd.Series],
        reference_tool_calls_series: Optional[pd.Series],
    ):
        self.input: Optional[pd.Series] = input_series
        self.output: Optional[pd.Series] = output_series
        self.ground_truth: Optional[pd.Series] = ground_truth_series # could also be called the source
        self.context: Optional[pd.Series] = context_series
        self.actual_tool_calls: Optional[pd.Series] = actual_tool_calls_series
        self.reference_tool_calls: Optional[pd.Series] = reference_tool_calls_series

    @staticmethod
    def from_series(
        input_series: Optional[pd.Series],
        output_series: Optional[pd.Series],
        ground_truth_series: Optional[pd.Series],
        context_series: Optional[pd.Series],
        actual_tool_calls_series: Optional[pd.Series],
        reference_tool_calls_series: Optional[pd.Series],
    ):
        return GenAIMetricInput(input_series, output_series, ground_truth_series, context_series, actual_tool_calls_series, reference_tool_calls_series)

    @staticmethod
    def from_df(input_df: pd.DataFrame, recipe_desc: GenAIEvalRecipeDesc):
        return GenAIMetricInput(
            input_df[recipe_desc.input_column_name],
            input_df.get(recipe_desc.output_column_name),
            input_df.get(recipe_desc.ground_truth_column_name),
            input_df.get(recipe_desc.context_column_name),
            input_df.get(recipe_desc.actual_tool_calls_column_name),
            input_df.get(recipe_desc.reference_tool_calls_column_name),
        )

    @staticmethod
    def from_single_entry(input: str, output: str, ground_truth: Optional[str], context: Union[str, List[str]], actual_tool_calls: Optional[List], reference_tool_calls: Optional[List]):
        return GenAIMetricInput(
            pd.Series([input]),
            pd.Series([output]),
            pd.Series([ground_truth]),
            pd.Series([context]),
            pd.Series(actual_tool_calls),
            pd.Series(reference_tool_calls),
        )

    def get(self, input_role: GenAIMetricInputRole) -> Optional[pd.Series]:
        if input_role == GenAIMetricInputRole.INPUT:
            return self.input
        elif input_role == GenAIMetricInputRole.OUTPUT:
            return self.output
        elif input_role == GenAIMetricInputRole.GROUND_TRUTH:
            return self.ground_truth
        elif input_role == GenAIMetricInputRole.CONTEXT:
            return self.context
        elif input_role == GenAIMetricInputRole.ACTUAL_TOOL_CALLS:
            return self.actual_tool_calls
        elif input_role == GenAIMetricInputRole.REFERENCE_TOOL_CALLS:
            return self.reference_tool_calls
        else:
            return MetricInputException(f"Unknown GenAIMetricInputRole: {input_role}")
