import logging
import os.path as osp

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor

from dataiku.core import dkujson

logger = logging.getLogger(__name__)


class ColumnImportanceHandler:
    """ Class to handle creation, saving and fetching of column importance for a model.

    Column importance should be distinguished from feature importance. Feature importance gives the importance
    of the preprocessed features (dummified columns, linear combinations...) while column importance
    gives the importance of the model input columns (i.e. before pre-processing).
    """
    # Name of the JSON file, in the model folder, where the column importance is persisted
    COLUMN_IMPORTANCE_FILENAME = "column_importance.json"
    # NOTE(review): not referenced anywhere in this class; presumably an upper bound checked by callers - confirm
    MAX_ALLOWED_FEATURES = 100000

    def __init__(self, model_folder_context, preprocessing_folder_context):
        """
        :param model_folder_context: folder accessor used here through `isfile`, `read_json` and `write_json`
        :param preprocessing_folder_context: folder accessor used here through `isfile` and `read_json`
        """
        self.model_folder_context = model_folder_context
        self.preprocessing_folder_context = preprocessing_folder_context

    def has_saved_column_importance(self):
        """ Tell whether a column importance file exists in the model folder

        :rtype: bool
        """
        return self.model_folder_context.isfile(self.COLUMN_IMPORTANCE_FILENAME)

    def get_column_importance(self):
        """ Load the column importance file in the model folder

        :rtype: pd.DataFrame
        """
        return pd.DataFrame(self.model_folder_context.read_json(self.COLUMN_IMPORTANCE_FILENAME))

    def compute_column_importance(self, all_columns, features, prepared_array, y_true, generated_features_mapping, save):
        """
        Compute the importance of each input column, and optionally save it in the model folder.

        `features` is replaced by the model's own feature list when the model folder already holds raw feature
        importances; `prepared_array` and `y_true` are only used when it does not.

        :param all_columns: names of the input columns; each one gets an entry in the result
        :type all_columns: list
        :param features: names of the preprocessed features
        :type features: list
        :param prepared_array: preprocessed data
        :type prepared_array: numpy.ndarray | scipy.sparse.csr_matrix
        :param y_true: target values
        :type y_true: numpy.ndarray
        :param generated_features_mapping: object resolving a preprocessed feature to its origin columns
        :type generated_features_mapping: dataiku.doctor.preprocessing.generated_features_mapping.GeneratedFeaturesMapping
        :param save: whether to write the result to COLUMN_IMPORTANCE_FILENAME in the model folder
        :type save: bool
        :return: dataframe with a "columns" and an "importances" column
        :rtype: pd.DataFrame
        """
        features, importances = self._build_preprocessed_feature_importance(features, prepared_array, y_true)
        column_importance = self.compute_column_importance_from_feature_importance(features, importances, all_columns,
                                                                                   generated_features_mapping)
        if save:
            logger.info("Saving column importance")
            self.model_folder_context.write_json(self.COLUMN_IMPORTANCE_FILENAME,
                                                 column_importance.to_dict(orient="list"))
        return column_importance

    def _build_preprocessed_feature_importance(self, preprocessed_features, preprocessed_array, y_true):
        """
        Build or retrieve preprocessed features importances

        The importances stored under "rawImportance" in the model's "iperf.json" are reused when present, otherwise
        they are estimated from the data. If a PCA feature selection was applied, the importances of the principal
        components are projected back onto the features that were fed to the PCA.

        :type preprocessed_features: list
        :type preprocessed_array: numpy.ndarray | scipy.sparse.csr_matrix
        :type y_true: numpy.ndarray
        :return: the feature names and their importances
        :rtype: (list, np.ndarray)
        """
        raw_model_importance = self.model_folder_context.read_json("iperf.json").get("rawImportance")
        if raw_model_importance is not None:
            logger.info("Reusing the model feature importance")
            preprocessed_features = raw_model_importance["variables"]
            preprocessed_feature_importances = np.array(raw_model_importance["importances"])
        else:
            preprocessed_feature_importances = ColumnImportanceHandler._compute_feature_importance(preprocessed_array,
                                                                                                   y_true)

        # Check if PCA was applied. In that case:
        # * preprocessed_feature_importances will correspond to the importance of [pca_component_1, ...]
        # * we cannot directly trace back the importance of the input columns from those importances
        # * therefore, we re-build intermediate importances [preprocessed_feature1, ...] from those that will then
        #   allow to get the column importances. This is done by projecting the principal components importances in the
        #   preprocessed features space, which is an approximation
        if self.preprocessing_folder_context.isfile("feature_selection.json"):
            feature_selection = self.preprocessing_folder_context.read_json("feature_selection.json")
            if feature_selection.get("method") == "PCA":
                selection_params = feature_selection.get("selection_params", {})
                rot = selection_params.get("rot")
                input_names = selection_params.get("input_names")
                if rot and input_names:
                    logger.info("Building preprocessed feature importances from principal components importances")
                    # In that case:
                    #   * We project the importances back in the preprocessed feature space
                    #   * We take the absolute value of the components' matrix in order to avoid different components
                    #     with opposite signs to cancel each other
                    #   * And then rescale them to sum to 1
                    components = np.array(rot)
                    projected = np.matmul(np.abs(components), preprocessed_feature_importances)
                    total = projected.sum()
                    if total != 0:
                        # Out-of-place division: an in-place `/=` would fail on an integer-typed array
                        projected = projected / total
                    else:
                        # Dividing by 0 would turn every importance into NaN; keep the (all-zero) values as they are
                        logger.warning("Projected feature importances sum to 0, skipping rescaling")
                    preprocessed_feature_importances = projected
                    preprocessed_features = input_names

        return preprocessed_features, preprocessed_feature_importances

    @staticmethod
    def compute_column_importance_from_feature_importance(preprocessed_features, preprocessed_feature_importances,
                                                          all_columns, generated_features_mapping):
        """
        Compute column importance from feature importance.

        Each column is assigned the sum of the importance of the features created from that column. When a feature
        originates from several columns, its importance is split evenly between them. Features whose importance is
        NaN are skipped (with a warning), and columns from which no feature originates get an importance of 0.

        :type preprocessed_features: list
        :type preprocessed_feature_importances: np.ndarray
        :type all_columns: list
        :type generated_features_mapping: dataiku.doctor.preprocessing.generated_features_mapping.GeneratedFeaturesMapping
        :return: dataframe with a "columns" and an "importances" column
        :rtype: pd.DataFrame
        :raises ValueError: if the origin columns of a feature cannot be resolved or accumulated
        """
        column_importances = dict.fromkeys(all_columns, 0.)
        for feature, importance in zip(preprocessed_features, preprocessed_feature_importances):
            if np.isnan(importance):
                logger.warning("Feature importance for '%s' is NaN", feature)
                continue
            try:
                columns = generated_features_mapping.get_origin_columns_from_feature(feature, all_columns)
                share = importance / len(columns) if columns else 0.
                for column in columns:
                    column_importances[column] += share
            except Exception as e:
                # Raising error here is acceptable, not all feature generation mechanisms are supported:
                raise ValueError("Unsupported feature generation mechanism {} for feature: {}".format(e, feature))
        return pd.DataFrame({"columns": list(column_importances.keys()),
                             "importances": list(column_importances.values())})

    @staticmethod
    def _compute_feature_importance(preprocessed_array, target):
        """
        Estimate the importance of the preprocessed features with a random forest regressor.

        Rows whose target is missing are dropped (for a multi-dimensional target, only its first column is checked),
        then at most 1000 rows are sampled without replacement, with a fixed seed, to fit the forest.

        :type preprocessed_array: numpy.ndarray | scipy.sparse.csr_matrix
        :type target: numpy.ndarray
        :return: the `feature_importances_` of the fitted forest
        :rtype: np.ndarray
        """
        # TODO: with scikit-learn 0.22 - Replace feature importances method by permutation importance
        logger.info("Computing feature importances")

        na_mask = pd.isna(target) if target.ndim == 1 else pd.isna(target[:, 0])
        preprocessed_no_na = preprocessed_array[~na_mask]
        target_no_na = target[~na_mask]

        nb_rows = preprocessed_no_na.shape[0]
        # Fixed seeds keep both the sampling and the forest reproducible from one run to the next
        idx = np.random.RandomState(1337).choice(nb_rows, min(1000, nb_rows), replace=False)
        clf = RandomForestRegressor(n_estimators=100, max_depth=5, random_state=0)
        clf.fit(preprocessed_no_na[idx], target_no_na[idx])
        return clf.feature_importances_
