import numpy as np
import pandas as pd
import logging

from dataiku.base.folder_context import build_folder_context
from dataiku.core import doctor_constants
from dataiku.core.dku_logging import LogLevelContext
from dataiku.doctor.posttraining.features_distribution import NumericFeatureDistributionComputer
from dataiku.doctor.posttraining.features_distribution import CategoricalFeatureDistributionComputer
from dataiku.doctor.posttraining.model_information_handler import build_model_handler
from dataiku.core.percentage_progress import PercentageProgress
from dataiku.doctor.prediction.common import check_classical_prediction_type
from dataiku.doctor.prediction.overrides.ml_overrides_params import OVERRIDE_INFO_COL
from dataiku.doctor.utils.metrics import log_odds

OTHERS_NAME = "__DKU_OTHERS__"
UNREPRESENTED_MODALITY_NAME = "__DKU_UNREPRESENTED__"

logger = logging.getLogger(__name__)


def compute(job_id, split_desc, core_params, preprocessing_folder, model_folder, split_folder, modellike_folder, computation_params, fmi):
    """Compute partial dependences for the requested features and save them.

    Results are written to "iperf.json" of `modellike_folder` when it is given, otherwise
    of the model folder returned by the model handler.

    :param job_id: identifier used to report progress
    :param computation_params: dict that must contain "features_to_compute" (iterable of
        feature names); optional keys are "debug_mode" (default False), "sample_size"
        (default 10000) and "random_state" (default 1337)
    :param modellike_folder: optional folder where results are saved
    :raises ValueError: if `computation_params` is None or lacks "features_to_compute"
    """
    if computation_params is None or "features_to_compute" not in computation_params:
        raise ValueError("'computation_params' should contains a key 'features_to_compute'")

    model_handler = build_model_handler(split_desc, core_params, preprocessing_folder, model_folder,
                                        split_folder, fmi, postcompute_folder=None)

    modellike_folder_context = (build_folder_context(modellike_folder) if modellike_folder
                                else model_handler.get_model_folder_context())

    prediction_type = model_handler.get_prediction_type()
    check_classical_prediction_type(prediction_type)

    features_to_compute = computation_params["features_to_compute"]
    debug_mode = computation_params.get("debug_mode", False)
    sample_size = computation_params.get("sample_size", 10000)
    random_state = computation_params.get("random_state", 1337)

    if model_handler.use_full_df():
        df, _ = model_handler.get_full_df()
    else:
        df, _ = model_handler.get_test_df()

    n_rows = df.shape[0]
    on_sample = sample_size < n_rows
    nb_records = sample_size if on_sample else n_rows
    # Sampling (even of all rows) shuffles the data; the index is reset so that row
    # positions can later be used to look up sample weights.
    df = df.sample(nb_records, random_state=random_state).reset_index(drop=True)

    nb_features = len(features_to_compute)
    progress = PartialDependenciesProgress(job_id, nb_features)
    saver = PartialDependenciesSaver(modellike_folder_context, model_handler.get_schema())
    computer = PartialDependencyComputer(df, model_handler, progress, debug_mode)

    for index, feature_name in enumerate(features_to_compute):
        drop_missing = model_handler.get_per_feature_col(feature_name).get("missing_handling") == "DROP_ROW"
        feature_type = model_handler.get_type_of_column(feature_name)
        is_dummified = False
        category_possible_value = None
        if feature_type == 'CATEGORY':
            is_dummified = model_handler.is_column_dummified(feature_name)
            category_possible_value = model_handler.category_possible_values(feature_name)
        pd_feature = PartialDependencyFeature(feature_type,
                                              feature_name,
                                              is_dummified,
                                              category_possible_value,
                                              drop_missing)
        result = computer.compute(pd_feature)
        saver.save(result, on_sample, nb_records, random_state)
        progress.set_percentage((index + 1) * 100 / nb_features)


class PartialDependenciesProgress(PercentageProgress):
    """Percentage progress for a partial-dependence job over several features.

    Sub-progress reported from inside a single feature computation is forwarded only
    when the job has exactly one feature; otherwise it is ignored.
    """

    def __init__(self, future_id, number_of_features):
        PercentageProgress.__init__(self, future_id)
        # How many features the job computes in total
        self.number_of_features = number_of_features

    def set_percentage_for_single_computation(self, percentage, no_fail=True):
        """Forward `percentage` to set_percentage only for single-feature jobs."""
        if self.number_of_features != 1:
            return
        self.set_percentage(percentage, no_fail=no_fail)


class PartialDependencyComputer:
    """Compute one-feature partial dependences of a model on a (sampled) dataframe.

    For each value a feature is evaluated at, every row of the dataframe gets the feature
    forced to that value and is re-predicted. The (optionally sample-weighted) mean of the
    resulting score is then compared with the mean score on the unmodified dataframe
    (``base_prediction``). The score is the raw prediction for regression and the clipped
    log-odds of the probabilities for classification (see ``_get_pd_value``).
    """

    def __init__(self, df, model_handler, progress, debug_mode=False, max_cats=30, batch_size=500e6):
        """
        :param pd.DataFrame df: rows the predictions are averaged over. Its index is used
            positionally to look up sample weights, so it presumably needs to be a 0..n-1
            RangeIndex (the module-level ``compute`` resets it).
        :param model_handler: handler of the trained model (prediction type, predict, weights)
        :param progress: object exposing ``set_percentage_for_single_computation``
        :param bool debug_mode: use DEBUG instead of INFO for the preprocessing logger while
            pd values are computed
        :param int max_cats: number of most frequent categories kept before the remaining ones
            are merged into a single OTHERS_NAME category
        :param float batch_size: in bytes, approximate memory budget of one prediction batch
        """

        self.prediction_type = model_handler.get_prediction_type()
        self.model_handler = model_handler

        self.dataframe = df
        self.progress = progress
        self.max_cats = max_cats
        self.batch_size = batch_size

        preprocessing_log_level = logging.DEBUG if debug_mode else logging.INFO
        self.log_level_context = LogLevelContext(preprocessing_log_level, [doctor_constants.PREPROCESSING_LOGGER_NAME])

        self.n_samples = self.dataframe.shape[0]

        if model_handler.get_sample_weight_variable() is not None:
            # Replace nan weights by zero because in the preprocessing steps,
            # rows with missing weights are dropped
            self.sample_weights = np.nan_to_num(df[model_handler.get_sample_weight_variable()].values)
            self.weighted_samples = np.sum(self.sample_weights)
        else:
            self.sample_weights = None
            self.weighted_samples = self.n_samples

        # Mean score on the unmodified data; every partial dependence is reported relative to it
        self.base_prediction = self._predict_and_get_pd_value(df.copy())

    def compute(self, pd_feature):
        """Compute the partial dependence of a NUMERIC or CATEGORY feature.

        :param PartialDependencyFeature pd_feature: feature to compute
        :rtype: PartialDependencyResult
        :raises ValueError: if the feature type is neither NUMERIC nor CATEGORY
        """
        if pd_feature.type == "NUMERIC":
            return self._compute_numeric(pd_feature)
        elif pd_feature.type == "CATEGORY":
            return self._compute_category(pd_feature)
        # Previously returned None here, which only failed later with an opaque AttributeError
        raise ValueError("Partial dependence is not supported for feature '{}' of type '{}'".format(
            pd_feature.name, pd_feature.type))

    def _compute_numeric(self, pd_feature):
        """Evaluate the partial dependence at the histogram scale of a numeric feature."""
        sample_weight_without_nans = np.nan_to_num(self.sample_weights)
        feature_distribution_computer = NumericFeatureDistributionComputer()
        feature_distribution_computer.compute(self.dataframe[pd_feature.name], sample_weight_without_nans)
        scale, distribution = feature_distribution_computer.get_histograms()
        indices_to_drop = []

        partial_dep = self._compute_partial_dependency(scale, pd_feature, indices_to_drop)

        partial_dep = partial_dep - self.base_prediction
        partial_dep = partial_dep.transpose()

        if self.prediction_type in {doctor_constants.REGRESSION, doctor_constants.BINARY_CLASSIFICATION}:
            # Always use a 2D array
            partial_dep = partial_dep[np.newaxis]

        return PartialDependencyResult(pd_feature, scale, distribution, partial_dep, indices_to_drop=indices_to_drop)

    def _compute_partial_dependency(self, values, pd_feature, indices_to_drop):
        """Return the pd value for each entry of `values` (one row per value).

        Values whose predictions are all dropped get ``base_prediction`` as a placeholder and
        their position is appended to `indices_to_drop` (mutated in place).
        """
        partial_dep = []
        nb_rows = self.n_samples

        # Creation of the batch dataframe: one full copy of the data per value, with the
        # feature column overwritten (np.repeat keeps the value order aligned with the copies)
        batched_dfs = pd.concat([self.dataframe] * values.shape[0], ignore_index=True)
        repeated_values = np.repeat(values, nb_rows)
        batched_dfs[pd_feature.name] = repeated_values
        self.progress.set_percentage_for_single_computation(20)

        # Computation requires predicting `nb_modalities * rows_in_test_set` rows.
        # So, we `predict_by_batch` not to overload memory usage.
        batch_preds = self._predict_by_batch(batched_dfs,
                                             progress=lambda prog: self.progress.set_percentage_for_single_computation(20 + 80 * prog))

        # Get pd values
        with self.log_level_context:
            for batch in range(len(values)):
                pd_value = self._get_pd_value(batch_preds.iloc[batch * nb_rows: (batch + 1) * nb_rows].reset_index(drop=True).dropna()) # We drop the index to match the weight index, see sc-145951
                if pd_value is None:
                    indices_to_drop.append(batch)
                    pd_value = self.base_prediction
                partial_dep.append(pd_value)

        return np.array(partial_dep)

    def _compute_category(self, pd_feature):
        """Evaluate the partial dependence for each modality of a categorical feature.

        Modalities are sorted by decreasing frequency; beyond ``max_cats`` the least frequent
        ones are aggregated. Dummified features get an extra UNREPRESENTED_MODALITY_NAME entry.
        """
        unrepresented_modalities = []
        indices_to_drop = []
        column = self.dataframe[pd_feature.name].fillna(doctor_constants.FILL_NA_VALUE)  # For the PDP, we replace nans by a str
        sample_weight_without_nans = np.nan_to_num(self.sample_weights)
        feature_distribution_computer = CategoricalFeatureDistributionComputer()
        feature_distribution_computer.compute(column, sample_weight_without_nans)
        scale, distribution = feature_distribution_computer.get_values_with_nans()
        scale_size = scale.shape[0]

        # Sorting the distribution and the scale, putting more frequent modalities first
        indices = np.argsort(-distribution)
        scale = scale[indices]
        distribution = distribution[indices]
        unrepresented_pd_value = self._predict_and_get_pd_value_for_unrepresented_modality(pd_feature)

        # Prunes all not known modality + drop rows that we already know their partial dependence
        computations_to_make, computations_to_make_indexes, partial_dep_dict = self._get_computations_to_make_categorical(
            indices_to_drop, pd_feature, scale, unrepresented_modalities, unrepresented_pd_value)

        if len(computations_to_make) > 0:
            for idx, partial_dependency in zip(computations_to_make_indexes,
                                               self._compute_partial_dependency(computations_to_make,
                                                                                pd_feature, indices_to_drop)):
                partial_dep_dict[idx] = partial_dependency

        partial_dep = np.array([partial_dep_dict[idx] for idx in range(scale_size)])

        if partial_dep.shape[0] > self.max_cats:
            partial_dep, scale, distribution = self.aggregate_less_frequent_values(partial_dep, scale, distribution)

        if pd_feature.is_dummified:
            # Add a fake modality that represents a modality the model doesn't know
            # It's only used to compare with others modalities and it's not present in the test
            # so its distribution is zero
            partial_dep = np.append(partial_dep, [unrepresented_pd_value], axis=0)
            distribution = np.append(distribution, 0.0)
            scale = np.append(scale, UNREPRESENTED_MODALITY_NAME)

        partial_dep = partial_dep - self.base_prediction
        partial_dep = partial_dep.transpose()

        if self.prediction_type in {doctor_constants.REGRESSION, doctor_constants.BINARY_CLASSIFICATION}:
            # Always use a 2D array
            partial_dep = partial_dep[np.newaxis]

        return PartialDependencyResult(pd_feature, scale, distribution, partial_dep,
                                       indices_to_drop=indices_to_drop,
                                       unrepresented_modalities=unrepresented_modalities)

    def _get_computations_to_make_categorical(self, indices_to_drop, pd_feature, scale, unrepresented_modalities,
                                              unrepresented_pd_value):
        """Split modalities into those that need a prediction and those already known.

        Mutates `indices_to_drop` and `unrepresented_modalities` in place.

        :return: (object array of values to predict, with the NA placeholder turned back into
            np.nan; their indexes in `scale`; dict index -> pd value for modalities resolved
            without predicting)
        """
        computations_to_make = np.empty(0, dtype='object')
        computations_to_make_indexes = []
        partial_dep_dict = {}
        for index, value in enumerate(scale):
            if value == doctor_constants.FILL_NA_VALUE and pd_feature.drop_missing:
                # All rows will be dropped, no prediction can be computed
                indices_to_drop.append(index)
                # Arbitrary value used here, should be dropped in the front
                pd_value = self.base_prediction
            else:
                # If the modality is not known by the model we know for sure that its partial dependence
                # has the same value as the unrepresented_pd_value, no need for a another computation
                if pd_feature.is_represented(value):
                    if value == doctor_constants.FILL_NA_VALUE:
                        value = np.nan
                    computations_to_make = np.append(computations_to_make, value)
                    computations_to_make_indexes.append(index)
                    continue
                else:
                    unrepresented_modalities.append(value)
                    pd_value = unrepresented_pd_value

                if pd_value is None:
                    # Arbitrary value used here, should be dropped in the front
                    indices_to_drop.append(index)
                    pd_value = self.base_prediction
            partial_dep_dict[index] = pd_value
        return computations_to_make, computations_to_make_indexes, partial_dep_dict

    def aggregate_less_frequent_values(self, partial_dep, scale, distribution):
        """Keep the first ``max_cats`` categories and merge the rest into OTHERS_NAME.

        Expects inputs already sorted by decreasing frequency. The merged pd value is the
        distribution-weighted average of the merged categories, or their plain mean when
        their total weight is zero.
        """
        new_scale = np.concatenate((scale[:self.max_cats], np.asarray([OTHERS_NAME])))

        distribution_to_keep = distribution[:self.max_cats]
        distribution_to_aggregate = distribution[self.max_cats:]

        partial_dep_to_keep = partial_dep[:self.max_cats]
        partial_dep_to_aggregate = partial_dep[self.max_cats:]

        aggregated_weight = np.sum(distribution_to_aggregate)
        aggregated_distribution = np.zeros((self.max_cats + 1))
        aggregated_distribution[:self.max_cats] = distribution_to_keep
        aggregated_distribution[-1] = aggregated_weight

        shape = (self.max_cats + 1, partial_dep.shape[1]) if partial_dep.ndim == 2 else (self.max_cats + 1)
        aggregated_partial_dep = np.zeros(shape)
        aggregated_partial_dep[:self.max_cats] = partial_dep_to_keep
        if aggregated_weight == 0:
            # np.average raises ZeroDivisionError when weights sum to zero
            aggregated_partial_dep[-1] = np.mean(partial_dep_to_aggregate, axis=0)
        else:
            aggregated_partial_dep[-1] = np.average(partial_dep_to_aggregate, axis=0, weights=distribution_to_aggregate)

        return aggregated_partial_dep, new_scale, aggregated_distribution

    def _predict_and_get_pd_value_for_unrepresented_modality(self, pd_feature):
        # Unrepresented modality is a modality that the model doesn't know (not in the train set
        # or too many modalities exist and this one has been discarded.
        # This function compute the partial dependence for a such modality
        # with a fake one named ${UNREPRESENTED_MODALITY_NAME}
        return self._predict_and_get_pd_value_for(pd_feature.name, UNREPRESENTED_MODALITY_NAME)

    def _predict_and_get_pd_value_for(self, col_name, value):
        """pd value when `col_name` is set to `value` on every row (on a copy of the data)."""
        df_copy = self.dataframe.copy()
        df_copy[col_name] = value
        return self._predict_and_get_pd_value(df_copy)

    def _predict_and_get_pd_value(self, df):
        pred = self.model_handler.predict(df, output_probas=True)
        return self._get_pd_value(pred)

    def _predict_by_batch(self, input_df, progress=None):
        """
        Split df to predict by batch of size batch_size
        :param pd.DataFrame input_df: to predict
        :param function progress: a callback method that will update the progress, taking values from 0 to 1
        :return: predictions reindexed like `input_df` (rows dropped by the model appear as NaN)
        """
        all_df_size = input_df.memory_usage(deep=True).sum()
        logger.info("Starting predict by batch with df shape {}".format(input_df.shape))
        pred_df_list = []
        # batch_size is a float number of bytes: cast to get an integral batch count
        nb_batches = int(1 + all_df_size // self.batch_size)
        logger.info("Number of batch: {} ".format(str(nb_batches)))

        # Split row positions rather than the DataFrame itself: np.array_split on a DataFrame
        # goes through the deprecated DataFrame.swapaxes. Section sizes are identical.
        row_positions = np.arange(input_df.shape[0])
        for i, positions in enumerate(np.array_split(row_positions, nb_batches)):
            batch = input_df.iloc[positions]
            pred_df_list.append(self.model_handler.predict(batch).reindex(batch.index))
            if progress is not None:
                progress((i + 1) / nb_batches)

        return pd.concat(pred_df_list).reindex(input_df.index)

    def _get_pd_value(self, pred):
        """Average score of a prediction dataframe, or None if it is empty.

        Regression uses the "prediction" column. Binary classification uses the clipped log-odds
        of the third column (presumably the positive-class probability; confirm column order).
        Multiclass uses the clipped log-odds of every column except "prediction" and the override
        info column. The average is weighted by the sample weights of the rows in `pred.index`
        when weights are defined.
        """
        clip_min = 0.01
        clip_max = 0.99

        if pred.empty:
            return None
        else:
            if self.sample_weights is not None:
                # remove rows of weights that could have been dropped by the preprocessing
                weights = self.sample_weights[pred.index]
            else:
                weights = None
            if self.prediction_type == doctor_constants.REGRESSION:
                pd_values = pred["prediction"].values
            elif self.prediction_type == doctor_constants.BINARY_CLASSIFICATION:
                prob_cols = pred.values[:, 2]
                pd_values = log_odds(prob_cols, clip_min=clip_min, clip_max=clip_max)
            elif self.prediction_type == doctor_constants.MULTICLASS:
                pred_filter = pred.filter(["prediction", OVERRIDE_INFO_COL])
                prob_cols = pred.drop(pred_filter, axis="columns").values
                pd_values = log_odds(prob_cols, clip_min=clip_min, clip_max=clip_max)
            else:
                raise ValueError("The prediction type '{}' is not supported for "
                                 "Partial dependence computation".format(self.prediction_type))

            if weights is not None:
                if len(pd_values.shape) == 1:
                    total_sum = np.nansum(pd_values * weights)
                else:
                    total_sum = np.nansum(pd_values * weights[:, np.newaxis], axis=0)
                return total_sum / np.nansum(weights)
            return np.nanmean(pd_values, axis=0)


class PartialDependencyFeature:
    """Description of a feature whose partial dependence is computed."""

    def __init__(self, feature_type, name, is_dummified=False, dummified_modalities=None, drop_missing=False):
        # Column type, e.g. "NUMERIC" or "CATEGORY"
        self.type = feature_type
        self.name = name
        self.is_dummified = is_dummified
        # Modalities kept by the dummification; only consulted when is_dummified is True
        self._dummified_modalities = dummified_modalities
        # Whether rows with a missing value for this feature are dropped by the preprocessing
        self.drop_missing = drop_missing

    def is_represented(self, value):
        """
        Tell whether `value` is a modality the model knows.

        Non-dummified columns always answer True (this could be refined depending on the
        feature handling); dummified ones check membership in the dummified modalities.
        :param value: modality of the feature
        :return: boolean
        """
        return (not self.is_dummified) or (value in self._dummified_modalities)


class PartialDependencyResult:
    """Plain container for the outcome of one partial-dependence computation.

    ``feature`` holds the ``pd_feature`` argument; the other attributes keep the names of
    their constructor arguments.
    """

    def __init__(self,
                 pd_feature,
                 scale,
                 distribution,
                 partial_dependence,
                 indices_to_drop=None,
                 unrepresented_modalities=None):
        # Feature the result belongs to
        self.feature = pd_feature
        # Points the partial dependence was evaluated at, and their distribution
        self.scale, self.distribution = scale, distribution
        self.partial_dependence = partial_dependence
        # Optional bookkeeping: positions to drop, and modalities unknown to the model
        self.indices_to_drop, self.unrepresented_modalities = indices_to_drop, unrepresented_modalities


class PartialDependenciesSaver:
    """Store partial-dependence results in the "iperf.json" file of a folder context."""

    def __init__(self, folder_context, schema):
        """
        :param folder_context: object exposing read_json / write_json
        :param dict schema: schema whose "columns" entries carry "name" and "type"
        """
        self.folder_context = folder_context
        self.dtypes = {column["name"]: column["type"] for column in schema["columns"]}

    def save(self, pd_result, on_sample, nb_records, random_state):
        """Replace the first stored entry for the result's feature (if any), append the new one.

        :return: the updated iperf dict, as written back to "iperf.json"
        """
        iperf = self.folder_context.read_json("iperf.json")
        entries = iperf.setdefault("partialDependencies", [])

        feature = pd_result.feature
        stale_position = next((pos for pos, entry in enumerate(entries)
                               if entry.get('feature') == feature.name), None)
        if stale_position is not None:
            del entries[stale_position]

        new_entry = dict(
            data=list(pd_result.partial_dependence),
            feature=feature.name,
            distribution=pd_result.distribution,
            computedPostTraining=True,
            isDate=self.dtypes[feature.name] in ("date", "dateonly", "datetimenotz"),
            unrepresentedModalities=pd_result.unrepresented_modalities,
            nbRecords=nb_records,
            onSample=on_sample,
            randomState=random_state,
        )

        if pd_result.indices_to_drop is not None:
            new_entry["indicesToDrop"] = pd_result.indices_to_drop

        # Categorical scales are stored as categories, numeric ones as bins
        scale_key = {'CATEGORY': "categories", 'NUMERIC': "featureBins"}.get(feature.type)
        if scale_key is not None:
            new_entry[scale_key] = list(pd_result.scale)

        entries.append(new_entry)
        self.folder_context.write_json("iperf.json", iperf)
        return iperf
