import logging

import numpy as np
import pandas as pd

from dataiku.doctor.prediction.regression_scoring import RegressionModelScorer
from dataiku.modelevaluation.drift.drift_model import train_drift_model
from dataiku.modelevaluation.drift.utils import _handle_test_error
from sklearn.decomposition import PCA
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

logger = logging.getLogger(__name__)


class DriftEmbedding(object):
    """
    Compute drift between reference and current embeddings.

    For every embedding column, three metrics are produced: the Euclidean distance and the cosine
    similarity between the mean reference vector and the mean current vector, and the Gini score of a
    classifier trained to tell reference samples apart from current samples.
    """

    # Embeddings with more dimensions than this are reduced with PCA before training the classifier
    _REDUCTION_DIM_THRESHOLD = 30
    # Number of PCA components kept when reduction is applied
    _REDUCED_DIM = 10
    # Maximum number of samples used to fit the PCA (keeps the fit fast on large inputs)
    _MAX_REDUCER_SAMPLES = 1000
    # Single seed shared by the sampling, the PCA and the train/test split
    _RANDOM_STATE = 1337

    def __init__(self, embedding_ref, embedding_cur, handle_drift_failure_as_error=False):
        """
        :param embedding_ref: Reference data embeddings, keyed by column name.
        :type embedding_ref: dict
        :param embedding_cur: Current data embeddings, keyed by column name.
        :type embedding_cur: dict
        :param handle_drift_failure_as_error: If True, raises an exception on a calculation failure.
        :type handle_drift_failure_as_error: bool
        """
        self.embedding_ref = embedding_ref
        self.embedding_cur = embedding_cur
        self.handle_drift_failure_as_error = handle_drift_failure_as_error

    def compute_drift(self):
        """
        Computes drift metrics for each embedding column of the current embeddings.

        :return: ``{"columns": {column: metrics}}`` with the core metrics of each column.
        :rtype: dict
        """
        column_results = {}

        for column in self.embedding_cur:
            logger.info("Embedding drift: computing embedding drift for column %s", column)
            y_test, y_prob_of_1, _ = self._compute_classifier(self.embedding_ref, self.embedding_cur, column)
            column_results[column] = self._compute_core_metrics(column, y_test, y_prob_of_1)
        return {"columns": column_results}

    def compute_drift_for_image(self, ref_df, cur_df, ref_folder_name, cur_folder_name):
        """
        Computes extended drift metrics for image embedding.

        On top of the core metrics, each column gets ``refPredictionInfos`` and ``curPredictionInfos``
        entries describing the classifier predictions on reference and current test samples. Both
        entries are None when the drift classifier could not be computed for the column.

        :param ref_df: Reference Dataset.
        :type ref_df: pd.DataFrame
        :param cur_df: Current Dataset.
        :type cur_df: pd.DataFrame
        :param ref_folder_name: Smart Name ID of the reference Managed Folder.
        :type ref_folder_name: str
        :param cur_folder_name: Smart Name ID of the current Managed Folder.
        :type cur_folder_name: str
        :return: metrics and prediction info for each column.
        :rtype: dict
        """
        column_results = {}

        for column in self.embedding_cur:
            logger.info("Embedding drift: computing embedding drift for column %s", column)
            y_test, y_prob_of_1, id_test = self._compute_classifier(self.embedding_ref, self.embedding_cur, column)
            metrics = self._compute_core_metrics(column, y_test, y_prob_of_1)
            if y_test is None:
                # The classifier failed (already reported by _compute_classifier): there are no
                # predictions to describe, so do not try to index into None.
                # NOTE(review): confirm that consumers of this dict accept null prediction infos.
                metrics["refPredictionInfos"] = None
                metrics["curPredictionInfos"] = None
            else:
                metrics["refPredictionInfos"] = self._get_class_prediction_infos(
                    y_test, y_prob_of_1, id_test, ref_df, ref_folder_name, for_class=0
                )
                metrics["curPredictionInfos"] = self._get_class_prediction_infos(
                    y_test, y_prob_of_1, id_test, cur_df, cur_folder_name, for_class=1
                )
            column_results[column] = metrics
        return {"columns": column_results}

    def _compute_core_metrics(self, column, y_test, y_prob_of_1):
        """Computes the set of metrics common to all drift computations.

        :param column: Name of the embedding column to process.
        :type column: str
        :param y_test: Actual labels (0 or 1), or None if the classifier failed.
        :type y_test: pd.Series or None
        :param y_prob_of_1: Predicted probabilities for class 1 from the test set, or None if the classifier failed.
        :type y_prob_of_1: np.ndarray or None
        :return: A dictionary with the keys ``euclidianDistance``, ``cosineSimilarity`` and ``classifierGini``.
        :rtype: dict
        """
        return {
            "euclidianDistance": self._euclidian_distance(self.embedding_ref, self.embedding_cur, column),
            "cosineSimilarity": self._cosine_similarity(self.embedding_ref, self.embedding_cur, column),
            "classifierGini": self._classifier_gini(y_test, y_prob_of_1, column),
        }

    @staticmethod
    def _mean_embeddings(ref, cur, column):
        """
        Averages the reference and the current embeddings of a column along their first axis.

        :param ref: Reference embeddings.
        :type ref: dict
        :param cur: Current embeddings.
        :type cur: dict
        :param column: Name of the embedding column to process.
        :type column: str
        :return: The mean reference vector and the mean current vector.
        :rtype: tuple(np.ndarray, np.ndarray)
        """
        return np.mean(ref[column], axis=0), np.mean(cur[column], axis=0)

    def _euclidian_distance(self, ref, cur, column):
        """
        Calculates the Euclidean distance between the mean vectors of the reference and current embeddings.

        :param ref: Reference embeddings.
        :type ref: dict
        :param cur: Current embeddings.
        :type cur: dict
        :param column: Name of the embedding column to process.
        :type column: str
        :return: The calculated Euclidean distance (or None on error).
        :rtype: float or None
        """
        logger.info("Computing euclidian distance metric on column %s", column)
        try:
            mean_ref_embedding, mean_cur_embedding = self._mean_embeddings(ref, cur, column)
            ed = np.linalg.norm(mean_cur_embedding - mean_ref_embedding)
            logger.info("Euclidian distance : %s", ed)
            return ed
        except Exception as err:
            return _handle_test_error(
                "euclidian distance",
                column,
                err,
                self.handle_drift_failure_as_error,
                None,
            )

    def _cosine_similarity(self, ref, cur, column):
        """
        Calculates the cosine similarity between the mean vectors of the reference and current embeddings.

        A zero-magnitude mean vector makes the similarity undefined; this is reported as a failure
        instead of silently producing NaN.

        :param ref: Reference embeddings.
        :type ref: dict
        :param cur: Current embeddings.
        :type cur: dict
        :param column: Name of the embedding column to process.
        :type column: str
        :return: The calculated cosine similarity (or None on error).
        :rtype: float or None
        """
        logger.info("Computing cosine similarity metric on column %s", column)
        try:
            mean_ref_embedding, mean_cur_embedding = self._mean_embeddings(ref, cur, column)

            dot_product = np.dot(mean_ref_embedding.flatten(), mean_cur_embedding.flatten())
            magnitudes_product = np.linalg.norm(mean_ref_embedding) * np.linalg.norm(mean_cur_embedding)
            if magnitudes_product == 0:
                # Dividing would only emit a numpy warning and return NaN
                raise ValueError("Cosine similarity is undefined for a zero-magnitude mean embedding")

            cs = dot_product / magnitudes_product
            logger.info("Cosine similarity : %s", cs)
            return cs
        except Exception as err:
            return _handle_test_error(
                "cosine similarity",
                column,
                err,
                self.handle_drift_failure_as_error,
                None,
            )

    def _compute_classifier(self, ref, cur, column):
        """
        Trains a drift classifier to distinguish reference embeddings from current embeddings.

        Reference samples are labelled 0 and current samples 1. Each sample also carries its position
        within its own set (positions restart at 0 for the current set), so test samples can later be
        mapped back to rows of the matching dataset.

        :param ref: Dictionary of reference embeddings.
        :type ref: dict
        :param cur: Dictionary of current embeddings.
        :type cur: dict
        :param column: Name of the embedding column to process.
        :type column: str
        :return: A tuple containing test set actual labels, test set predicted probabilities for class 1,
                and the test sample IDs. All three are None when the computation failed.
        :rtype: tuple(pd.Series, np.ndarray, pd.Series) or tuple(None, None, None)
        """
        logger.info("Computing classifier model metric on column %s", column)
        try:
            # Looked up inside the try so that a missing column goes through the same failure handling
            ref_list = ref[column]
            cur_list = cur[column]

            concatenated_df = np.concatenate([ref_list, cur_list])
            labels = pd.Series([0] * len(ref_list) + [1] * len(cur_list))
            ids = pd.Series(list(range(len(ref_list))) + list(range(len(cur_list))))

            # No need to perform reduction if the embedding vector is already small
            use_reduction = len(ref_list[0]) > self._REDUCTION_DIM_THRESHOLD
            if use_reduction:
                logger.info("Performing PCA to reduce the embeddings dimension")
                reducer = PCA(n_components=self._REDUCED_DIM, random_state=self._RANDOM_STATE)

                samples_for_reducer = min(self._MAX_REDUCER_SAMPLES, len(concatenated_df))  # Boost up reduction speed
                # Dedicated seeded generator: keeps the subsample (and so the metric) reproducible
                # without touching the global numpy random state.
                rng = np.random.RandomState(self._RANDOM_STATE)
                samples = rng.choice(len(concatenated_df), samples_for_reducer, replace=False)
                reducer.fit(concatenated_df[samples])
                concatenated_df = reducer.transform(concatenated_df)

            x_train, x_test, y_train, y_test, _, id_test = train_test_split(
                concatenated_df, labels, ids, stratify=labels, random_state=self._RANDOM_STATE
            )

            clf = train_drift_model(x_train, y_train)
            y_prob_of_1 = clf.predict_proba(x_test)[:, 1]
            return y_test, y_prob_of_1, id_test

        except Exception as err:
            # The handler is defined elsewhere; presumably it raises when failures must be errors.
            # Whatever it returns otherwise, hand back a 3-tuple because every caller unpacks three values.
            _handle_test_error("classifier gini", column, err, self.handle_drift_failure_as_error, None)
            return None, None, None

    def _classifier_gini(self, y_test, y_prob_of_1, column):
        """
        Calculates a modified Gini score from the drift classifier's test set predictions.

        The score is ``2 * AUC - 1`` floored at 0.

        :param y_test: Actual labels (0 or 1) from the test set, or None if the classifier failed.
        :type y_test: pd.Series or None
        :param y_prob_of_1: Predicted probabilities for class 1 from the test set, or None if the classifier failed.
        :type y_prob_of_1: np.ndarray or None
        :param column: Name of the column.
        :type column: str
        :return: The calculated Gini score (or None on error or when no predictions are available).
        :rtype: float or None
        """
        if y_test is None or y_prob_of_1 is None:
            # The classifier failure was already reported; do not report it a second time
            logger.info("Skipping classifier gini metric on column %s: no classifier predictions", column)
            return None

        logger.info("Computing classifier gini metric on column %s", column)
        try:
            auc = roc_auc_score(y_test, y_prob_of_1)
            gini = max(2 * auc - 1, 0)
            logger.info("Model AUC : %s", auc)
            logger.info("Classifier gini : %s", gini)
            return gini

        except Exception as err:
            return _handle_test_error("classifier gini", column, err, self.handle_drift_failure_as_error, None)

    def _get_class_prediction_infos(self, y_test, y_prob_of_1, id_test, df, folder_name, for_class):
        """
        Extracts prediction details for a specific target class.

        :param y_test: Actual labels (0 or 1) from the test set.
        :type y_test: pd.Series
        :param y_prob_of_1: Predicted probabilities for class 1 from the test set.
        :type y_prob_of_1: np.ndarray
        :param id_test: The iloc-based IDs corresponding to the test samples.
        :type id_test: pd.Series
        :param df: Original Dataset.
        :type df: pd.DataFrame
        :param folder_name: Name of the Managed Folder.
        :type folder_name: str
        :param for_class: Target origin label (0 for ref, 1 for cur).
        :type for_class: int
        :return: A dictionary with PDF data (``x``, ``pdf``), top 5 samples (``top5``), and folder name
                (``managedFolderSmartName``).
        :rtype: dict
        """
        # y_test, y_prob_of_1 and id_test are aligned by position, so select with a plain boolean array
        mask = np.asarray(y_test == for_class)
        y_prob_of_1_for_class = y_prob_of_1[mask]
        ids = np.asarray(id_test)[mask]

        # Reference samples (class 0) are ranked by lowest probability of being "current",
        # current samples (class 1) by highest.
        top_5_values = self._extract_top_5_values(df.iloc[ids], y_prob_of_1_for_class, for_class == 0)
        prediction_infos = RegressionModelScorer.compute_preds_pdf(y_prob_of_1_for_class, xmin=0, xmax=1)

        return {
            "x": list(prediction_infos["x"]),
            "pdf": list(prediction_infos["pdf"]),
            "top5": top_5_values.to_dict(orient="records"),
            "managedFolderSmartName": folder_name,
        }

    def _extract_top_5_values(self, df, values, is_ascending, predict_column_name="predictions"):
        """
        Finds the top/bottom 5 rows from a DataFrame based on an external list of prediction values.

        Rows and values are matched by position, so the result does not depend on the index labels of
        ``df`` (which may be non-unique).

        :param df: Dataset.
        :type df: pd.DataFrame
        :param values: The list of prediction values corresponding to the Dataset.
        :type values: np.ndarray
        :param is_ascending: Sort direction: True keeps the 5 smallest values, False the 5 largest.
        :type is_ascending: bool
        :param predict_column_name: The name for the new column of prediction scores.
        :type predict_column_name: str
        :return: A new Dataset containing the top 5 rows with their prediction scores.
        :rtype: pd.DataFrame
        :raises ValueError: If ``df`` and ``values`` do not have the same length.
        """
        if len(df) != len(values):
            raise ValueError("Inconsistency in lengths: expect {0} for len(df), got {1}.".format(len(values), len(df)))

        # Default RangeIndex: the index of the selected values is directly their row position in df
        values_series = pd.Series(np.asarray(values), name=predict_column_name)

        if is_ascending:
            top_5_series = values_series.nsmallest(5)
        else:
            top_5_series = values_series.nlargest(5)

        top_5_df = df.iloc[top_5_series.index.values].reset_index(drop=True)
        # Assign raw values so that no index alignment takes place
        return top_5_df.assign(**{predict_column_name: top_5_series.values})
