import logging
from typing import Dict, Optional, Tuple

import pandas as pd
from bert_score import BERTScorer
from bert_score.utils import model2layers
from torch import Tensor

from dataiku.core.model_provider import ModelProvider
from dataiku.llm.evaluation.exceptions import BertScoreException
from dataiku.llm.evaluation.genai_metrics_input import GenAIMetricInput, GenAIMetricInputRole
from dataiku.llm.evaluation.utils import failure_utils
from dataiku.llm.evaluation.utils.metrics_utils import mean_or_none

logger = logging.getLogger(__name__)


def compute_bert_score(
    metric_inputs: GenAIMetricInput, hf_model_id: str, hf_connection: Optional[str]
) -> Tuple[Dict[str, Optional[float]], pd.DataFrame]:
    """Compute the BERT Score metrics (precision, recall, F1) for ``metric_inputs``.

    The inputs are first passed through ``failure_utils.filter_null_rows`` for the OUTPUT and
    GROUND_TRUTH roles (presumably dropping rows where either is missing - confirm against that
    helper), then handed to ``_compute_bert_score``.

    :param metric_inputs: evaluation columns; the output and ground truth columns are the ones scored.
    :param hf_model_id: HuggingFace id of the model used for scoring.
    :param hf_connection: name of the HuggingFace connection used to fetch the model, or ``None``.
    :return: a tuple ``(aggregated metrics, per-row metrics DataFrame)`` as produced by
        ``_compute_bert_score``.
    :raises BertScoreException: on any failure of the computation itself. The original error is
        attached both as the second constructor argument and as the explicit ``__cause__``.
    """
    filtered_columns = failure_utils.filter_null_rows(
        metric_inputs,
        ["BERT Score"],
        {GenAIMetricInputRole.OUTPUT, GenAIMetricInputRole.GROUND_TRUTH},
    )

    try:
        return _compute_bert_score(filtered_columns, hf_model_id, hf_connection)
    except Exception as e:
        # Deliberately broad: this is the boundary where every failure is converted into the
        # metric-specific exception. "from e" keeps the root cause explicit in tracebacks.
        raise BertScoreException(
            f"An error happened during the computation of BERT Score metrics {e}", e
        ) from e


def _compute_bert_score(
    filtered_columns: GenAIMetricInput, hf_model_id: str, hf_connection: Optional[str]
) -> Tuple[Dict[str, Optional[float]], pd.DataFrame]:
    """Run bert-score on the output / ground truth columns and package the results.

    :param filtered_columns: metric inputs whose ``output`` and ``ground_truth`` columns are scored.
    :param hf_model_id: HuggingFace id of the scoring model; must resolve to a key of
        ``bert_score.utils.model2layers`` (see ``_get_model_name_for_bert_score``).
    :param hf_connection: HuggingFace connection name, or ``None`` for unauthenticated access.
    :return: ``(bert_score_perf, bert_metrics_df)`` where ``bert_score_perf`` maps
        ``bertScorePrecision`` / ``bertScoreRecall`` / ``bertScoreF1`` to ``mean_or_none`` of the
        corresponding scores, and ``bert_metrics_df`` holds the raw scores under the same column
        names, indexed like the input column.
    :raises BertScoreException: if the model id is not supported by bert-score or if fetching the
        model fails.
    """
    # NOTE(review): the index is taken from the `input` column while the scores are computed from
    # `output` / `ground_truth`; this assumes all columns share one index - confirm.
    initial_index: pd.Index = filtered_columns.input.index  # Keep index for output  # type: ignore

    candidate = filtered_columns.output.to_list()
    reference = filtered_columns.ground_truth.to_list()

    # Resolve (and thereby validate) the model name *before* fetching anything: the lookup raises
    # for ids that bert-score does not list.
    # NOTE(review): the download below uses the original hf_model_id, not the validated key - confirm
    # that is intended.
    bs_compatible_model_id = _get_model_name_for_bert_score(hf_model_id)
    try:
        # Note that the connection can be None, in which case requests to HF Hub will be unauthenticated, which has rate
        # limiting implications (and of course private models won't be available)
        model_path = ModelProvider().get_or_download_model(hf_model_id, hf_connection)
    except Exception as e:
        raise BertScoreException(f"Error downloading model {hf_model_id} from HF Hub: {e}", e) from e
    # Logged outside the try block so that only the model fetch is reported as a download error.
    logger.info(f"Running BERT-Score with cached model {hf_model_id}")

    # The model is loaded from a local path, so the layer count is passed explicitly, looked up
    # under the bert-score compatible name.
    num_layers = model2layers[bs_compatible_model_id]
    scorer = BERTScorer(model_type=model_path, num_layers=num_layers)

    # Needed to avoid type ambiguity with the return type of scorer.score
    bert_score_precision: Tensor
    bert_score_recall: Tensor
    bert_score_f1: Tensor
    bert_score_precision, bert_score_recall, bert_score_f1 = scorer.score(candidate, reference)  # type: ignore

    # Aggregated, dataset-level values.
    bert_score_perf = {
        "bertScorePrecision": mean_or_none(pd.Series(bert_score_precision)),
        "bertScoreRecall": mean_or_none(pd.Series(bert_score_recall)),
        "bertScoreF1": mean_or_none(pd.Series(bert_score_f1)),
    }
    logger.info(f"BERT Score results: {bert_score_perf}")

    # Row-level values, re-attached to the original index.
    bert_metrics = {
        "bertScorePrecision": bert_score_precision.numpy(),
        "bertScoreRecall": bert_score_recall.numpy(),
        "bertScoreF1": bert_score_f1.numpy(),
    }
    bert_metrics_df = pd.DataFrame(bert_metrics, index=initial_index)
    return bert_score_perf, bert_metrics_df


def create_empty_bert_score(metric_inputs: GenAIMetricInput) -> Tuple[Dict[str, None], pd.DataFrame]:
    """Return the placeholder BERT Score result built by ``failure_utils.create_empty_metrics``.

    The metric names passed along are the three keys used for BERT Score results in this module:
    ``bertScorePrecision``, ``bertScoreRecall`` and ``bertScoreF1``.

    :param metric_inputs: the metric inputs, forwarded unchanged to ``create_empty_metrics``.
    :return: per the annotation, a dict of ``None`` values and a DataFrame.
    """
    bert_metric_names = ["bertScorePrecision", "bertScoreRecall", "bertScoreF1"]
    return failure_utils.create_empty_metrics(metric_inputs, bert_metric_names)


def _get_model_name_for_bert_score(hf_model_id: str) -> str:
    """Return the key under which ``hf_model_id`` is listed in ``bert_score.utils.model2layers``.

    The id is tried as given first, then with everything up to the last ``/`` removed.

    :param hf_model_id: HuggingFace model id, with or without an organisation prefix.
    :return: the first of the two candidates that is a key of ``model2layers``.
    :raises BertScoreException: if neither candidate is a key of ``model2layers``.
    """
    # Why two candidates: the HuggingFace hub has aliases/redirects for common models. For instance,
    # bert-base-uncased lives at https://huggingface.co/google-bert/bert-base-uncased, but
    # https://huggingface.co/bert-base-uncased redirects to the same page. HuggingFace connections support
    # custom models with user-provided ids, so a user could use either google-bert/bert-base-uncased or just
    # bert-base-uncased. Meanwhile, bert_score.utils.model2layers contains model keys in *both formats*, making
    # bert-score actually *depend* on HF aliases: creating a BertScorer with
    # model_type="google-bert/bert-base-uncased" instead of "bert-base-uncased" gives an error unless num_layers
    # is passed manually, but passing deberta-base instead of microsoft/deberta-base *also* gives an error.
    # So, we need to do some finagling. Don't blame me, blame bert-score.
    #
    # Important: we should always check that the model is in model2layers before using it. We mostly trust the
    # authors of bert-score to have a reasonable list of supported models here - if we allowed arbitrary models,
    # users could inadvertently download malicious code.
    # NOTE(review): only the last path segment is checked in the fallback, so "<any-org>/<listed-name>" also
    # passes - confirm callers do not rely on this function to vet the organisation prefix.
    short_name = hf_model_id.split("/")[-1]
    for candidate_key in (hf_model_id, short_name):
        if candidate_key in model2layers:
            return candidate_key

    logger.error(f"Could not find model name {hf_model_id}. Available models are: {model2layers.keys()}")
    raise BertScoreException(f"Model {hf_model_id} not found")
