import logging

import numpy as np
import pandas as pd

from dataiku.modelevaluation.drift.drift_model import train_drift_model
from dataiku.modelevaluation.drift.utils import _handle_test_error
from sklearn.decomposition import PCA
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

logger = logging.getLogger(__name__)


class DriftEmbedding(object):
    """
    Compute drift metrics between reference and current embeddings.

    For every column of the current embeddings, three metrics are computed:

    - ``euclidianDistance``: euclidian distance between the mean reference
      embedding vector and the mean current embedding vector;
    - ``cosineSimilarity``: cosine similarity between those two mean vectors;
    - ``classifierGini``: gini (``2 * AUC - 1``, floored at 0) of a classifier
      trained to tell reference rows apart from current rows.

    When a metric fails, the error is passed to ``_handle_test_error`` together
    with ``handle_drift_failure_as_error`` (and a default of None), and whatever
    it returns is used as the metric value.

    :param embedding_ref: reference embeddings, indexable by column name
        (NOTE(review): each ``embedding_ref[column]`` is assumed to be a 2D
        array-like of shape (n_rows, n_dims) - confirm with callers)
    :param embedding_cur: current embeddings with the same layout; iterating
        over it yields the column names to analyse
    :param handle_drift_failure_as_error: forwarded to ``_handle_test_error``
    """
    # Seed shared by every randomized step (PCA subsample, PCA, train/test split).
    RANDOM_SEED = 1337
    # Embeddings wider than this are PCA-reduced before training the drift classifier.
    MAX_DIMS_WITHOUT_REDUCTION = 30
    # Number of principal components kept by the reduction.
    PCA_COMPONENTS = 10
    # Maximum number of rows used to fit the PCA (keeps the fit fast).
    MAX_PCA_FIT_SAMPLES = 1000

    embedding_ref = None
    embedding_cur = None
    handle_drift_failure_as_error = False

    def __init__(self, embedding_ref, embedding_cur, handle_drift_failure_as_error=False):
        self.embedding_ref = embedding_ref
        self.embedding_cur = embedding_cur
        self.handle_drift_failure_as_error = handle_drift_failure_as_error

    def compute_drift(self):
        """
        Compute every embedding drift metric for each column of the current embeddings.

        :return: ``{"columns": {column: {"euclidianDistance": ...,
            "cosineSimilarity": ..., "classifierGini": ...}}}``
        """
        column_results = {}

        for column in self.embedding_cur:
            # Lazy logging args: a tuple column name must not be %-expanded eagerly.
            logger.info(u"Embedding drift: computing embedding drift for column %s", column)
            column_results[column] = {
                "euclidianDistance": self._euclidian_distance(self.embedding_ref, self.embedding_cur, column),
                "cosineSimilarity": self._cosine_similarity(self.embedding_ref, self.embedding_cur, column),
                "classifierGini": self._classifier_gini(self.embedding_ref, self.embedding_cur, column)
            }

        return {"columns": column_results}

    @staticmethod
    def _mean_embeddings(ref, cur, column):
        """Return the (reference, current) mean embedding vectors of ``column`` (mean over axis 0)."""
        return np.mean(ref[column], axis=0), np.mean(cur[column], axis=0)

    def _euclidian_distance(self, ref, cur, column):
        """
        Euclidian distance between the mean reference and mean current embeddings of ``column``.

        Any failure is delegated to ``_handle_test_error``.
        """
        logger.info("Computing euclidian distance metric on column %s", column)
        try:
            mean_ref_embedding, mean_cur_embedding = self._mean_embeddings(ref, cur, column)
            ed = np.linalg.norm(mean_cur_embedding - mean_ref_embedding)
            logger.info("Euclidian distance : %s", ed)
            return ed
        except Exception as err:
            return _handle_test_error("euclidian distance", column, err, self.handle_drift_failure_as_error, None)

    def _cosine_similarity(self, ref, cur, column):
        """
        Cosine similarity between the mean reference and mean current embeddings of ``column``.

        The similarity is undefined when either mean vector is the zero vector;
        that case, like any other failure, is delegated to ``_handle_test_error``
        instead of silently producing NaN.
        """
        logger.info("Computing cosine similarity metric on column %s", column)
        try:
            mean_ref_embedding, mean_cur_embedding = self._mean_embeddings(ref, cur, column)
            ref_norm = np.linalg.norm(mean_ref_embedding)
            cur_norm = np.linalg.norm(mean_cur_embedding)
            if ref_norm == 0 or cur_norm == 0:
                # numpy would return nan (with a RuntimeWarning) on this division.
                raise ValueError("cosine similarity is undefined when a mean embedding is the zero vector")

            dot_product = np.dot(mean_ref_embedding.flatten(), mean_cur_embedding.flatten())
            cs = dot_product / (ref_norm * cur_norm)
            logger.info("Cosine similarity : %s", cs)
            return cs
        except Exception as err:
            return _handle_test_error("cosine similarity", column, err, self.handle_drift_failure_as_error, None)

    def _classifier_gini(self, ref, cur, column):
        """
        Gini of a classifier trained to separate reference rows (label 0) from current rows (label 1).

        Rows of both sides are concatenated, optionally PCA-reduced, split into a
        stratified train/test split, and the classifier's test AUC is turned into
        ``max(2 * AUC - 1, 0)``. Any failure, including a column missing on either
        side, is delegated to ``_handle_test_error``.
        """
        logger.info("Computing classifier model metric on column %s", column)
        try:
            ref_list = ref[column]
            cur_list = cur[column]
            concatenated = np.concatenate([ref_list, cur_list])
            labels = pd.Series([0] * len(ref_list) + [1] * len(cur_list))

            # No need to perform reduction if the embedding vector is already small
            if concatenated.shape[1] > self.MAX_DIMS_WITHOUT_REDUCTION:
                concatenated = self._reduce_dimensions(concatenated)

            x_train, x_test, y_train, y_test = train_test_split(
                concatenated, labels, stratify=labels, random_state=self.RANDOM_SEED)

            clf = train_drift_model(x_train, y_train)

            y_pred = clf.predict_proba(x_test)[:, 1]
            auc = roc_auc_score(y_test, y_pred)
            # A classifier no better than random (AUC <= 0.5) means no detectable drift.
            gini = max(2 * auc - 1, 0)
            logger.info("Model AUC : %s", auc)
            logger.info("Classifier gini : %s", gini)
            return gini

        except Exception as err:
            return _handle_test_error("classifier gini", column, err, self.handle_drift_failure_as_error, None)

    def _reduce_dimensions(self, embeddings):
        """
        Project ``embeddings`` onto ``PCA_COMPONENTS`` principal components.

        The PCA is fitted on a reproducible subset of at most ``MAX_PCA_FIT_SAMPLES``
        rows (to boost up the fit) and then applied to every row.
        """
        logger.info("Performing PCA to reduce the embeddings dimension")
        reducer = PCA(n_components=self.PCA_COMPONENTS, random_state=self.RANDOM_SEED)
        reducer.fit(self._reducer_fit_sample(embeddings, self.MAX_PCA_FIT_SAMPLES, self.RANDOM_SEED))
        return reducer.transform(embeddings)

    @staticmethod
    def _reducer_fit_sample(embeddings, max_samples, seed):
        """
        Return at most ``max_samples`` rows of ``embeddings``, drawn at random without replacement.

        A dedicated seeded generator keeps the draw reproducible and leaves the
        global numpy random state untouched.
        """
        n_rows = len(embeddings)
        rng = np.random.RandomState(seed)
        picked = rng.choice(n_rows, min(max_samples, n_rows), replace=False)
        return embeddings[picked]
