# coding: utf-8
from __future__ import unicode_literals

import pandas as pd

from dataiku.core import doctor_constants
from dataiku.base.utils import safe_unicode_str
from dataiku.doctor.diagnostics import diagnostics
import scipy.stats as sps


# Significance level for the train/test target distribution tests: a p-value strictly below
# this threshold is reported as a distribution mismatch diagnostic.
PVALUE_THRESHOLD = 0.05


class DatasetSanityCheckDiagnostic(diagnostics.DiagnosticCallback):
    """ See in the documentation machine-learning/diagnostics.html#dataset-sanity-checks """
    def __init__(self):
        """ Register the callback under the dataset sanity checks diagnostic type. """
        diagnostic_type = diagnostics.DiagnosticType.ML_DIAGNOSTICS_DATASET_SANITY_CHECKS
        super(DatasetSanityCheckDiagnostic, self).__init__(diagnostic_type)
        # Per-class target counts, filled by the dataset loading hooks for classification tasks
        # and compared against each other in on_scoring_end
        self.train_feature_counts = None
        self.test_feature_counts = None

    def on_load_train_dataset_end(self, prediction_type=None, df=None, target_variable=None):
        diagnostics = []
        self.check_train_dataset(df, diagnostics)
        if target_variable is not None and prediction_type in (doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS,
                                                               doctor_constants.DEEP_HUB_IMAGE_CLASSIFICATION):
            self.train_feature_counts = self.check_balance(df[target_variable], "train", diagnostics)
        return diagnostics

    def on_load_test_dataset_end(self, prediction_type=None, df=None, target_variable=None):
        diagnostics = []
        self.check_test_dataset(df, diagnostics)
        if target_variable is not None and prediction_type in (doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS,
                                                               doctor_constants.DEEP_HUB_IMAGE_CLASSIFICATION):
            self.test_feature_counts = self.check_balance(df[target_variable], "test", diagnostics)
        return diagnostics

    def on_load_evaluation_dataset_end(self, df=None, univariate_drift=None, per_feature=None, target_column=None, target_remapping=None, prediction_type=None):
        diagnostics = []
        self.check_evaluation_dataset(df, diagnostics, target_column, prediction_type)
        self.check_evaluation_new_values(univariate_drift, per_feature, diagnostics)
        self.check_evaluation_new_target_cats(df, target_column, target_remapping, prediction_type, diagnostics)
        return diagnostics

    def on_class_definition(self, prediction_column_values=None, target_column_values=None, classes=None):
        diagnostics = []

        if classes is None or len(classes) == 0:
            self.check_predictions_against_target_column(prediction_column_values, target_column_values, diagnostics)
        else:
            self.check_predictions_against_user_defined_class(prediction_column_values, classes, diagnostics)
            self.check_user_defined_classes_against_target_column(classes, target_column_values, diagnostics)
        return diagnostics

    def on_scoring_end(self, scoring_params=None, transformed_test=None, transformed_train=None, with_sample_weight=False):
        diagnostics = []
        if scoring_params is not None and scoring_params.prediction_type == doctor_constants.REGRESSION and \
                transformed_train is not None and transformed_test is not None:
            series1 = transformed_train["target"]
            series2 = transformed_test["target"]
            statistic, pvalue = sps.ks_2samp(series1, series2)
            if pvalue < PVALUE_THRESHOLD:
                diagnostics.append("Target variable distribution in test data does not match the training data distribution (p-value={:.3f}), metrics could be misleading".format(pvalue))
        elif scoring_params is not None and scoring_params.prediction_type in [doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS, doctor_constants.DEEP_HUB_IMAGE_CLASSIFICATION] \
                and self.train_feature_counts is not None and self.test_feature_counts is not None:

            if len(self.train_feature_counts) != len(self.test_feature_counts):
                diagnostics.append("Test and train dataset do not contain the same number of classes")
            else:
                train_features_total = self.train_feature_counts.sum()
                test_features_total = self.test_feature_counts.sum()
                weight = test_features_total / float(train_features_total)
                _, pvalue = sps.chisquare(self.train_feature_counts * weight, f_exp=self.test_feature_counts.reindex(self.train_feature_counts.index))
                if pvalue < PVALUE_THRESHOLD:
                    diagnostics.append("Target variable distribution in test data does not match the training data distribution (p-value={:.3f}), metrics could be misleading".format(pvalue))
        return diagnostics

    def on_processing_all_kfold_end(self, prediction_type=None, folds=None, with_sample_weight=False, perf_data=None):
        if folds is None:
            return []
        for fold in folds:
            diagnostics = self.on_scoring_end(with_sample_weight=with_sample_weight, **fold)
            # Don't spam the user with several diagnostics of the same type
            if len(diagnostics) > 0:
                return diagnostics
        return []

    @staticmethod
    def check_balance(serie, kind, diagnostics):
        counts = serie.value_counts(dropna=False)
        classes_count = counts.shape[0]
        if classes_count == 2:
            imbalanced_threshold = .5
            # Balance < .5 only works well for 2 classes, otherwise we need to use some kind of heuristics
            balance = sps.entropy(counts, base=2)
            if balance < imbalanced_threshold:
                msg = "The {} dataset is imbalanced (balance={:.2f}), metrics can be misleading".format(kind, balance)
                diagnostics.append(msg)
        elif classes_count > 2:
            min_norm = counts[-1] / float(counts.sum())
            uniform_distribution_threshold = 1. / 5
            if classes_count * min_norm < uniform_distribution_threshold:
                percent = min_norm * 100
                perfectly_balanced = (1. / classes_count) * 100.
                min_class = counts.index[-1]  # serie.value_counts() sort in descending order, the min is the last one
                msg = "The {} dataset is imbalanced (class '{}' is only represented in {:.2f}% of rows; a well balanced dataset would contain ~{:.2f}%)," \
                      " metrics can be misleading".format(kind, safe_unicode_str(min_class), percent, perfectly_balanced)
                diagnostics.append(msg)
        return counts

    @staticmethod
    def check_train_dataset(train_df, diagnostics):
        size = train_df.shape[0]
        if size <= 1000:
            message = "Training set might be too small ({} rows) for robust training".format(size)
            diagnostics.append(message)

    @staticmethod
    def check_test_dataset(test_df, diagnostics):
        size = test_df.shape[0]
        if size <= 1000:
            message = "Test set might be too small ({} rows) for reliable performance estimation".format(size)
            diagnostics.append(message)

    @staticmethod
    def check_evaluation_dataset(eval_df, diagnostics, target_column, prediction_type):
        size = eval_df.shape[0]
        if size <= 1000:
            message = "Evaluation set might be too small ({} rows) for reliable performance estimation".format(size)
            diagnostics.append(message)

        # target column might not be in the eval dataset (when perf metrics are skipped)
        if target_column and target_column in eval_df.columns and DatasetSanityCheckDiagnostic._is_categorical(prediction_type):
            nb_target_values = len(eval_df[target_column].unique())
            if nb_target_values == 1:
                diagnostics.append("The evaluation dataset contains only one label in the target column {} : some metrics might not get successfully computed"
                                   .format(target_column))

    @staticmethod
    def check_evaluation_new_values(univariate_drift, per_feature, diagnostics):
        if univariate_drift is not None and per_feature is not None:
            for column, column_params in per_feature.items():
                if column_params.get("type") != "CATEGORY" or column_params.get("role") != "INPUT":
                    continue
                if column not in univariate_drift.keys():
                    continue
                new_values_percentage = univariate_drift[column].get("newValuesPercentage", 0)
                if new_values_percentage is None:
                    message = "Univariate drift computation failed for categorical column {col}".format(col=column)
                    diagnostics.append(message)
                elif new_values_percentage > 0:
                    message = ("New values found in evaluation data that were not present in reference data for categorical column {col}. These rows are "
                               "excluded from univariate drift analysis computation (PSI & CHI2) but are included in performance metrics".format(col=column))
                    diagnostics.append(message)

    @staticmethod
    def check_evaluation_new_target_cats(eval_df, target_column, target_remapping, prediction_type, diagnostics):
        can_check = (target_column and target_column in eval_df.columns) and DatasetSanityCheckDiagnostic._is_categorical(prediction_type) and target_remapping
        if can_check:
            eval_target_values = set(eval_df[target_column].unique())
            train_target_values = set([item["sourceValue"] for item in target_remapping])
            new_values = list(eval_target_values - train_target_values)

            if pd.isna(new_values).any():
                message = "There are null values in the target column of the evaluation data. These rows have been excluded from evaluation."
                diagnostics.append(message)
                # Filter out the na values
                new_values =  [v for v in new_values if not pd.isna(v)]
            if len(new_values) > 0:
                message = ("There are new categories in the target column of the evaluation data that were not present during training. "
                           "These rows have been excluded from evaluation.")
                diagnostics.append(message)

    @staticmethod
    def check_predictions_against_user_defined_class(prediction_column_values, class_values, diagnostics):
        if prediction_column_values is not None and class_values is not None and len(class_values) > 0:
            intersection = len(set(class_values).intersection(prediction_column_values))
            nb_classes = len(class_values)
            if nb_classes > intersection:
                if intersection > 0 and nb_classes > 1:
                    message = "Prediction uses {} of the {} configured classes.".format(intersection, nb_classes)
                    diagnostics.append(message)
                elif intersection == 0 and nb_classes > 1:
                    message = "Prediction uses none of the {} configured classes.".format(nb_classes)
                    diagnostics.append(message)
                elif intersection == 0 and nb_classes == 1:
                    message = "Prediction doesn't use the configured class."
                    diagnostics.append(message)

    @staticmethod
    def check_predictions_against_target_column(prediction_column_values, target_column_values, diagnostics):
        if prediction_column_values is not None and target_column_values is not None:
            intersection = len(set(target_column_values).intersection(prediction_column_values))
            nb_labels = len(target_column_values)
            if nb_labels > intersection:
                if intersection > 0 and nb_labels > 1:
                    message = "Prediction uses {} of the {} classes in labels.".format(intersection, nb_labels)
                    diagnostics.append(message)
                elif intersection == 0 and nb_labels > 1:
                    message = "Prediction uses none of the {} classes in labels.".format(nb_labels)
                    diagnostics.append(message)
                elif intersection == 0 and nb_labels == 1:
                    message = "Prediction doesn't use the class in labels."
                    diagnostics.append(message)

    @staticmethod
    def check_user_defined_classes_against_target_column(class_values, target_column_values, diagnostics):
        if class_values is not None and len(class_values) > 0 and target_column_values is not None:
            intersection = len(set(target_column_values).intersection(class_values))
            nb_labels = len(target_column_values)
            if nb_labels > intersection:
                if intersection > 0 and nb_labels > 1:
                    message = "Class definition uses {} of the {} classes in labels.".format(intersection, nb_labels)
                    diagnostics.append(message)
                elif intersection == 0 and nb_labels > 1:
                    message = "Class definition uses none of the {} classes in labels.".format(nb_labels)
                    diagnostics.append(message)
                elif intersection == 0 and nb_labels == 1:
                    message = "Class definition doesn't use the class in labels."
                    diagnostics.append(message)

    @staticmethod
    def _is_categorical(prediction_type):
        """ Tell whether `prediction_type` is binary or multiclass classification. """
        categorical_types = (doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS)
        return prediction_type in categorical_types
