import copy
import logging

import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedKFold

from dataiku.base.utils import safe_convert_to_string
from dataiku.base.utils import safe_unicode_str
from dataiku.core import doctor_constants

# Logger shared by the preprocessing code (used below e.g. to report categorical values unseen at fit time).
# Its level is forced to DEBUG as a side effect of importing this module, so DEBUG records are not
# filtered at the logger level (attached handlers may still filter them).
preproc_logger = logging.getLogger(doctor_constants.PREPROCESSING_LOGGER_NAME)
preproc_logger.setLevel(logging.DEBUG)


class CategoricalEncoderBase(object):
    """
    Base class for categorical feature encoders.

    Attributes:
        - pd.DataFrame encoding_map: mapping between feature's categorical values and encoded values.
            It stores the unique categories as index and the different values as columns. Every column
            corresponds to a separate target value (in the case of MULTICLASS classification) and/or k-fold.
        - pd.Series category_counts: counts of the feature's unique categorical values
    """
    __slots__ = ("encoding_map", "category_counts")

    DEFAULT_VALUE = "_default_"
    NA = "_NA_"

    # sc-84522: labels used in the encoding maps of models trained with DSS<10
    LEGACY_DEFAULT_VALUE = "__default__"
    LEGACY_NA = "__NULL__"

    def __init__(self):
        self.encoding_map = None
        self.category_counts = None

    @property
    def is_fitted(self):
        """True once an encoding map has been set."""
        return self.encoding_map is not None

    def _special_labels(self):
        """Return the (na_label, default_label) pair used as index entries of encoding_map."""
        index = self.encoding_map.index
        if CategoricalEncoderBase.LEGACY_NA in index or CategoricalEncoderBase.LEGACY_DEFAULT_VALUE in index:
            # Model trained with DSS<10
            return CategoricalEncoderBase.LEGACY_NA, CategoricalEncoderBase.LEGACY_DEFAULT_VALUE
        return CategoricalEncoderBase.NA, CategoricalEncoderBase.DEFAULT_VALUE

    @property
    def default_value(self):
        """
        Row of encoding_map used for categories unseen at fit time.

        Uses the same (legacy-aware) default label as transform, so maps of models trained with
        DSS<10 (which use "__default__") are supported too.
        """
        _, default_label = self._special_labels()
        return self.encoding_map.loc[default_label]

    def transform(self, series):
        """
        Encode a categorical series with encoding_map.

        Missing values are treated as a category of their own, and values absent from encoding_map
        are encoded with the default row.

        :param pd.Series series: categorical values to encode. NOTE: its missing values are filled in place.
        :return: pd.DataFrame with one column per encoding_map column, rows in the order (and with the
            index) of series
        """
        self.encoding_map.index = safe_convert_to_string(pd.Series(self.encoding_map.index))

        # sc-84522 : Fix default and na value for scoring with models trained with DSS<10
        na_value, default_value = self._special_labels()

        # Consider na as a category in itself.
        series.fillna(na_value, inplace=True)
        series = safe_convert_to_string(series)

        # Replace values unseen at train time by default value
        cat_values_fitted = self.encoding_map.index
        unfound_values = set(series.unique()) - set(cat_values_fitted)
        for v in unfound_values:
            preproc_logger.debug("Found value %s not in map", v)
        if unfound_values:
            # One vectorized membership pass, instead of Series.replace(list, ...) which tests every
            # listed value separately (slow for high-cardinality features with many unseen values)
            series = series.where(series.isin(cat_values_fitted), default_value)
        df = pd.DataFrame({"__feature": series})

        # Merge series with encoding map to create encoded data frame
        # More precisely:
        #     - left_on is the column to join on (i.e. df's __feature column)
        #     - right_index indicates that we use the right DataFrame's index as the join key (i.e. the encoding map)
        #     - how: use only keys from left frame, similar to a SQL left outer join; preserve key order
        result_df = df.merge(self.encoding_map, left_on="__feature", right_index=True, how='left')
        del result_df["__feature"]
        return result_df

    def _update_encoding_map(self, other_mapping):
        """
        Update encoding map - it can store values for multiple target classes
        :param pd.DataFrame other_mapping: mapping to use for encoding_map update (set as is when no map
            exists yet, otherwise outer-merged on the category index)
        """
        if self.encoding_map is None:
            self.encoding_map = other_mapping
        else:
            self.encoding_map = self.encoding_map.merge(other_mapping, left_index=True, right_index=True, how='outer')


class WrappingCategoricalEncoder(CategoricalEncoderBase):
    """
    Categorical feature encoder delegating the actual encoding work to another encoder instance.
    """
    __slots__ = ("encoder",)

    def __init__(self, encoder):
        """
        :param CategoricalEncoderBase encoder: categorical feature encoder that does the encoding
        """
        super(WrappingCategoricalEncoder, self).__init__()
        self.encoder = encoder

    def get_reportable_map(self):
        """
        Build the report table: category counts side by side with the encoded values,
        most frequent categories first.

        :rtype: pd.DataFrame
        """
        report = pd.concat([self.category_counts, self.encoding_map], axis=1, sort=False)
        return report.sort_values(by="counts", ascending=False)


class CategoricalSimpleEncoder(WrappingCategoricalEncoder):
    """
    Single-pass (no KFold) wrapper around a categorical feature encoder.

    On top of the wrapped encoder, it names the mapping columns after the target value, merges the
    mappings of successive target values (multiclass) into encoding_map, and keeps category_counts
    for reporting.
    """

    def fit(self, series, target_series, target_val=None):
        """
        Fit the wrapped encoder on the whole series and merge its (renamed) mapping into encoding_map.

        :param pd.Series series: categorical feature. NOTE: its missing values are filled in place.
        :param pd.Series target_series: target values
        :param target_val: target class encoded by this call (classification), None otherwise
        """
        # Missing values form a category of their own
        series.fillna(CategoricalEncoderBase.NA, inplace=True)

        # Kept for the reportable map
        counts = series.value_counts()
        counts.name = "counts"
        self.category_counts = counts

        self.encoder.fit(series, target_series)
        fitted_map = self.encoder.encoding_map
        column = "all_target_values" if target_val is None else "target_value:" + safe_unicode_str(target_val)
        fitted_map.columns = [column]

        self._update_encoding_map(fitted_map)

    def fit_transform(self, series, target_series, target_val=None):
        """
        Fit on series (see fit) then encode that same series with the wrapped encoder.

        :return: encoded feature, as returned by the wrapped encoder's transform
        """
        self.fit(series, target_series, target_val)
        encoded = self.encoder.transform(series)
        return encoded


class CategoricalKFoldEncoder(WrappingCategoricalEncoder):
    """
    KFold wrapper for categorical feature encoding.

    The feature is encoded out-of-fold (each row is encoded by an encoder fitted on the other folds),
    while the encoding_map kept for other data is the average of the per-fold mappings.
    """

    __slots__ = ('prediction_type', 'k', 'seed')

    def __init__(self, encoder, prediction_type, k=5, seed=1337):
        """
        :param encoder: categorical feature encoder to wrap
        :type encoder: CategoricalEncoderBase
        :param prediction_type: Prediction task type (binary, multiclass, regression)
        :type prediction_type: str
        :param k: number of folds (default=5, same as split params default)
        :type k: int
        :param seed: random seed, for reproducibility (default=1337, same as split params default)
        :type seed: int
        """
        super(CategoricalKFoldEncoder, self).__init__(encoder)
        self.prediction_type = prediction_type
        self.k = k
        self.seed = seed

    def fit(self, series, target_series, target_val=None):
        # You should use the fit_transform method which handles fitting and transforming on separate splits
        # and thus avoids leaking the encoded row into the transformed feature as in CategoricalSimpleEncoder
        raise NotImplementedError("For CategoricalKFoldEncoder use fit_transform")

    @staticmethod
    def _column_name(target_val, suffix=""):
        """Return the encoded column name for target_val (None when not classification), followed by suffix."""
        if target_val is not None:
            # Classification
            return "target_value:" + safe_unicode_str(target_val) + suffix
        return "all_target_values" + suffix

    def _build_cv(self):
        """Return the shuffled k-fold splitter matching self.prediction_type."""
        # Only use stratified kfold for classification tasks
        if self.prediction_type in {doctor_constants.BINARY_CLASSIFICATION,
                                    doctor_constants.MULTICLASS}:
            return StratifiedKFold(n_splits=self.k, shuffle=True, random_state=self.seed)
        if self.prediction_type == doctor_constants.REGRESSION:
            return KFold(n_splits=self.k, shuffle=True, random_state=self.seed)
        raise ValueError("Prediction type not matching allowed types")

    def fit_transform(self, series, target_series, target_val=None):
        """
        Encode series out-of-fold and merge the fold-averaged mapping into encoding_map.

        For each of the k folds, a deep copy of the wrapped encoder is fitted on the in-fold rows and used
        to transform the out-of-fold rows, so no row is encoded with a mapping fitted on its own target.
        The per-fold mappings are averaged (a category missing from a fold's mapping counts as the global
        target mean for that fold) and merged into encoding_map, to be used on validation / test data.

        :param pd.Series series: categorical feature. NOTE: its missing values are filled in place.
        :param pd.Series target_series: target values, presumably sharing series' index (TODO confirm with callers)
        :param target_val: target class encoded by this call (classification), None otherwise
        :return: single-column pd.DataFrame, rows in the order of series (when target_series shares its index)
        :raises ValueError: if prediction_type is not binary classification, multiclass or regression
        """
        # Validate the prediction type before mutating the input or the encoder state
        cv = self._build_cv()

        # Consider na as a category in itself.
        series.fillna(CategoricalEncoderBase.NA, inplace=True)

        # Save the counts for the reportable map
        self.category_counts = series.value_counts()
        self.category_counts.name = "counts"

        # Fixed internal column keys: keying the frame by the feature's own name made a feature named
        # "target" overwrite the actual target column. The Series names seen by the encoder are unchanged.
        df = pd.DataFrame({"__kfold_feature": series, "__kfold_target": target_series})
        features = df["__kfold_feature"].rename(series.name or "feature")
        targets = df["__kfold_target"].rename("target")

        global_mean = target_series.mean()
        impact_coded = []
        oof_positions = []
        mapping_computations = []

        for split, (infold, oofold) in enumerate(cv.split(features, targets)):
            # Fit using infold data
            feature_encoder = copy.deepcopy(self.encoder)
            feature_encoder.fit(features.iloc[infold], targets.iloc[infold])

            # Use encoder to transform oof data, remembering which original rows these are
            impact_coded.append(feature_encoder.transform(features.iloc[oofold]))
            oof_positions.append(oofold)

            # Save the encoding map for every fold
            mapping = feature_encoder.encoding_map
            mapping.columns = [self._column_name(target_val, "_fold_" + safe_unicode_str(split))]
            mapping_computations.append(mapping)

        mapping_computation_df = pd.concat(mapping_computations, axis=1, sort=True)

        # Use the global mean to replace all values which did not appear in a given fold
        mapping_computation_df.fillna(global_mean, inplace=True)

        # Mean mapping will be used to transform validation / test data
        mean_mapping = pd.DataFrame(mapping_computation_df.mean(axis=1))
        mean_mapping.columns = [self._column_name(target_val)]
        self._update_encoding_map(mean_mapping)

        # The out-of-fold row sets partition all rows: put the encoded rows back in their original positions
        # (unlike sort_index, this is also correct for unsorted or duplicated index labels)
        order = np.argsort(np.concatenate(oof_positions))
        impact_coded_df = pd.concat(impact_coded, axis=0).iloc[order]
        impact_coded_df.columns = [self._column_name(target_val)]
        return impact_coded_df
