import logging
import statistics
from collections import defaultdict
from time import time

import numpy as np
import pandas as pd

from dataiku.base.utils import package_is_at_least
from dataiku.core import doctor_constants
from dataiku.core.dku_logging import LogLevelContext
from dataiku.doctor.timeseries.train.split_handler import TimeseriesSingleHorizonDefaultSplitHandler
from dataiku.doctor.timeseries.utils import timeseries_iterator, get_dataframe_of_timeseries_identifier, \
    SINGLE_TIMESERIES_IDENTIFIER
# Module-level logger used by PermutationImportanceComputer for progress, timing and skip messages.
logger = logging.getLogger(__name__)

class PermutationImportanceComputer:
    """
    Computes the permutation feature importance for a time series forecasting model.

    At each iteration, every permutable input column (and its rolling window columns) is shuffled
    within each time series. The importance of a feature is the MASE obtained when only that
    feature's processed columns are taken from the shuffled data, minus the MASE on the unpermuted
    data. Results are written to ``permutation_importance.json`` in the folder context.
    """

    def __init__(self, model_scorer, folder_context,
                 algorithm, core_params, full_timeseries_preprocessing, timeseries_identifier_columns, estimator, input_columns,
                 shift_map, use_only_generated_features):
        """
        Initializes the PermutationImportanceComputer.

        :param model_scorer: Scorer object exposing 'quick_evaluate_mase'.
        :param folder_context: Context for file operations.
        :param algorithm: Algorithm object defining model type.
        :param core_params: Core parameters including 'target_variable' and 'timeVariable'.
        :param full_timeseries_preprocessing: Preprocessing pipeline handler.
        :param timeseries_identifier_columns: Columns defining individual time series.
        :param estimator: The trained forecasting model.
        :param input_columns: List of all feature column names passed to the model.
        :param shift_map: The ShiftMap of the model based on the feature generation preprocessing_params.
        :param use_only_generated_features: Whether the model uses only features from feature generation, ignoring parent features.
        """
        # Fixed seed: the same data always yields the same permutations.
        self.rng = np.random.RandomState(1337)
        self.model_scorer = model_scorer
        self.algorithm = algorithm
        self.timeseries_identifier_columns = timeseries_identifier_columns
        self.full_timeseries_preprocessing = full_timeseries_preprocessing
        self.folder_context = folder_context
        self.split_handler = TimeseriesSingleHorizonDefaultSplitHandler(core_params)
        self.target_variable = core_params["target_variable"]
        self.time_variable = core_params["timeVariable"]
        self.estimator = estimator
        # Sorted copy: column order must not affect the order of RNG draws.
        self.input_columns = sorted(input_columns)

        self.shift_map = shift_map
        self.use_only_generated_features = use_only_generated_features

        self._initialise_all_importances()
        # True when preprocessing changes the schema: data is then re-processed before each split.
        self.must_always_reprocess = False
        # Unpermuted split, lazily computed by _get_permuted_data.
        self.base_historical_df = None
        self.base_test_df = None
        # Fully shuffled copy of the data and its split, refreshed at each iteration.
        self.permuted_df = None
        self.permuted_historical_df = None
        self.permuted_test_df = None

    @staticmethod
    def supports_permutation_importance(algorithm, preprocessing_params):
        """
        Tells whether permutation importance can be computed for this algorithm and preprocessing.

        Requires the algorithm to support external features, at least one NUMERIC or CATEGORY
        external feature the algorithm uses, and no TEXT external feature it uses.

        :param algorithm: Algorithm object exposing EXTERNAL_FEATURES_COMPATIBILITY.
        :param dict preprocessing_params: Preprocessing params containing a "per_feature" dict.
        :return: Whether permutation importance is supported.
        """
        compatibility = algorithm.EXTERNAL_FEATURES_COMPATIBILITY
        supports_external_features = compatibility.supports_external_features()
        has_compatible_external_features = False
        has_incompatible_external_features = False
        for feature in preprocessing_params["per_feature"].values():
            is_used_input = (feature["role"] == "INPUT"
                             or (feature["role"] == "INPUT_PAST_ONLY" and compatibility.supports_past_only_external_features()))
            if not is_used_input:
                continue
            if feature["type"] in ["NUMERIC", "CATEGORY"]:
                has_compatible_external_features = True
            if feature["type"] == "TEXT":
                # text handling often generates massive amounts of new columns, not currently supported
                has_incompatible_external_features = True
                break

        return (supports_external_features
                and has_compatible_external_features
                and not has_incompatible_external_features)

    def _initialise_all_importances(self):
        """
        Initializes the 'all_importances' dictionary, loading previously computed
        results from 'permutation_importance.json' if that file exists.
        """
        self.all_importances = {}
        if self.folder_context.isfile("permutation_importance.json"):
            self.all_importances = self.folder_context.read_json("permutation_importance.json")

    def _get_permutable_columns(self):
        """Returns input columns to shuffle: all except the target, time and identifier columns."""
        excluded = {self.target_variable, self.time_variable}
        excluded.update(self.timeseries_identifier_columns or [])
        return [column for column in self.input_columns if column not in excluded]

    def compute_permutation_importance(self, full_df, per_identifier, n_iterations):
        """
        Main entry point for computing permutation importance.

        Depending on `per_identifier`, it computes importance either globally
        across all series or individually for each time series ID. Results are
        stored in self.all_importances ('importances'/'nShuffles', or
        'perIdentifierPermutationImportance'/'nShufflesPerIdentifier') and written
        to 'permutation_importance.json'.

        :param pd.DataFrame full_df: The complete raw dataset.
        :param bool per_identifier: If True, computes importance for each unique
                                    series identifier individually.
        :param int n_iterations: The number of times to shuffle and re-evaluate each feature.
        """
        start_time = time()
        logger.info("Starting permutation importance")

        preprocess_on_full_df = self.algorithm.ONE_MODEL_FOR_MULTIPLE_TS
        with LogLevelContext(logging.CRITICAL,
                             doctor_constants.PREPROCESSING_RELATED_LOGGER_NAMES):
            full_df_processed = self.full_timeseries_preprocessing.fit_and_process(
                full_df,
                None,
                preprocess_on_full_df,
                save_data=False,
            )
            pipeline = next(iter(self.full_timeseries_preprocessing.pipeline_by_timeseries.values()))

        # Derived on every call so a value left by a previous run cannot leak into this one.
        # A non-empty mapping means preprocessing changes the schema: keep the raw data and
        # re-process it after every permutation (see _split_data).
        self.must_always_reprocess = bool(pipeline.generated_features_mapping.mapping)
        if not self.must_always_reprocess:
            # Schema is unchanged: process once and permute the processed data directly.
            full_df = full_df_processed

        if per_identifier:
            per_identifier_importances = self._calc_perm_imp(n_iterations, full_df, True)
            self.all_importances["perIdentifierPermutationImportance"] = per_identifier_importances
            self.all_importances["nShufflesPerIdentifier"] = n_iterations
        else:
            identifier_importances = self._calc_perm_imp(n_iterations, full_df)
            self.all_importances["nShuffles"] = n_iterations
            self.all_importances["importances"] = identifier_importances

        elapsed_time = time() - start_time
        logger.info("Permutation importance calc took {}s".format(elapsed_time))
        self.folder_context.write_json("permutation_importance.json", self.all_importances)

    def _shuffle_permuted_df(self):
        """Shuffles self.permuted_df in place, then refreshes its historical/test split."""
        shuffle_start_time = time()
        self._permute_columns()
        self.permuted_historical_df, self.permuted_test_df, _ = next(self._split_data(self.permuted_df))
        shuffle_end_time = time()
        logger.info("Shuffle took {}".format(shuffle_end_time - shuffle_start_time))

    def _get_permuted_data(self, full_df, column, timeseries_identifier=None):
        """
        Returns the unpermuted historical and test dataframes in which every processed column
        derived from `column` (the column itself, its generated features and its rolling windows)
        is replaced by its counterpart from the current shuffled split.

        :param pd.DataFrame full_df: The input DataFrame (used to build the unpermuted split once).
        :param str column: The name of the column whose derived features are permuted.
        :param str timeseries_identifier: If provided (and not the single-series identifier),
                                          only keeps the rows matching the identifier.
        :return: A tuple containing (permuted_historical_df, permuted_test_df).
        """
        if self.base_historical_df is None:
            self.base_historical_df, self.base_test_df, _ = next(self._split_data(full_df))

        # deep copy needed for old pandas versions
        deep_copy = not package_is_at_least(pd, "1.4.0")

        historical_df = self.base_historical_df.copy(deep=deep_copy)
        test_df = self.base_test_df.copy(deep=deep_copy)

        preprocessed_feature_names = set()
        if column in self.base_historical_df:
            preprocessed_feature_names.add(column)
        for pipeline in self.full_timeseries_preprocessing.pipeline_by_timeseries.values():
            preprocessed_feature_names.update(pipeline.generated_features_mapping.get_features_from_origin_column(column))
        preprocessed_feature_names.update(self._get_rolling_windows_columns(self.base_historical_df.columns, column))

        for preprocessed_feature_name in sorted(preprocessed_feature_names):
            historical_df[preprocessed_feature_name] = self.permuted_historical_df[preprocessed_feature_name]
            test_df[preprocessed_feature_name] = self.permuted_test_df[preprocessed_feature_name]

        # _split_data yields unfiltered frames for the single-series identifier; stay consistent with it
        # so that base and permuted scores are computed on the same rows.
        if timeseries_identifier and timeseries_identifier != SINGLE_TIMESERIES_IDENTIFIER:
            return (get_dataframe_of_timeseries_identifier(historical_df, timeseries_identifier),
                    get_dataframe_of_timeseries_identifier(test_df, timeseries_identifier))
        return historical_df, test_df

    def _split_data(self, full_df, per_identifier=False):
        """
        Splits data into historical and test parts using the first split of the split handler.
        When self.must_always_reprocess is set, full_df is fit and processed again before splitting.

        :param pd.DataFrame full_df: The input DataFrame (raw or permuted).
        :param bool per_identifier: If True, creates a data per identifier generator.
        :return: generator of tuples: (historical_df, test_df, timeseries_identifier)
        """
        with LogLevelContext(logging.CRITICAL, doctor_constants.PREPROCESSING_RELATED_LOGGER_NAMES):
            if self.must_always_reprocess:
                preprocess_on_full_df = self.algorithm.ONE_MODEL_FOR_MULTIPLE_TS
                transformed_full_df = self.full_timeseries_preprocessing.fit_and_process(
                    full_df,
                    None,
                    preprocess_on_full_df,
                    save_data=False,
                )
                _, test_df, historical_df = next(self.split_handler.split(transformed_full_df))
            else:
                _, test_df, historical_df = next(self.split_handler.split(full_df))

        if per_identifier and self.timeseries_identifier_columns:
            for timeseries_identifier, _ in timeseries_iterator(
                    full_df, self.timeseries_identifier_columns
            ):
                yield get_dataframe_of_timeseries_identifier(historical_df, timeseries_identifier), get_dataframe_of_timeseries_identifier(test_df, timeseries_identifier), timeseries_identifier
        else:
            yield historical_df, test_df, SINGLE_TIMESERIES_IDENTIFIER

    def _calc_perm_imp(self, iterations, full_df, per_identifier=False):
        """
        Calculates the permutation importance per time series or for the aggregated dataset.

        :param int iterations: Number of permutation runs.
        :param pd.DataFrame full_df: The data slice to evaluate.
        :param bool per_identifier: Whether it should be computed per identifier or not.
        :return: If per_identifier is False, a list of dictionaries containing the aggregated
                 importance results (empty when iterations is 0); otherwise a dict mapping each
                 time series identifier to such a list.
        """
        # The cached unpermuted split belongs to the previous full_df, if any: rebuild it for this one.
        self.base_historical_df = None
        self.base_test_df = None
        self.permuted_df = full_df.copy()  # copy the original df only once, then shuffle it in place
        df_per_timeseries = list(self._split_data(full_df, per_identifier=per_identifier))
        permutable_columns = self._get_permutable_columns()
        # The skip decision only depends on full_df and the column: compute it once per column.
        skip_by_column = {}
        raw_importances = {}

        for iteration in range(iterations):
            iter_start_time = time()
            logger.info("Permutation importance: Iteration {}".format(iteration))
            self._shuffle_permuted_df()
            for historical_df, test_df, timeseries_identifier in df_per_timeseries:
                scores_by_column = raw_importances.setdefault(timeseries_identifier, defaultdict(list))
                base_score = self.model_scorer.quick_evaluate_mase(self.estimator, historical_df, test_df, test_df)

                for column in permutable_columns:
                    if column not in skip_by_column:
                        skip_by_column[column] = self._should_skip_permutation(None, full_df, column)
                    if skip_by_column[column]:
                        scores_by_column[column].append(0)
                        continue

                    permuted_historical_df, permuted_test_df = self._get_permuted_data(full_df, column, timeseries_identifier)

                    eval_start_time = time()
                    perm_score = self.model_scorer.quick_evaluate_mase(self.estimator, permuted_historical_df, test_df, permuted_test_df)
                    eval_end_time = time()

                    logger.info("Permutation importance - identifier {} - column {} - Eval took {}s".format(timeseries_identifier, column, eval_end_time - eval_start_time))

                    scores_by_column[column].append(perm_score - base_score)

            iter_end_time = time()
            logger.info("Permutation importance: Iteration {} finished in {}s".format(iteration, iter_end_time-iter_start_time))

        importances = {
            timeseries_identifier: self.aggregate_importances(scores_by_column)
            for timeseries_identifier, scores_by_column in raw_importances.items()
        }
        if not per_identifier:
            return importances.get(SINGLE_TIMESERIES_IDENTIFIER, [])
        return importances

    def aggregate_importances(self, importances):
        """
        Aggregates the raw importance scores (list of floats per feature) into
        mean and standard deviation, and formats the result as a list of dictionaries.

        :param importances: Dictionary mapping feature name (str) to a list of
                            scores (float) from each iteration (Dict[str, List[float]]).
        :return: A list where each dict contains {'featureName', 'importance' (mean), 'importanceSTD' (std)} (List[Dict[str, Any]]).
                 An empty score list gives a mean of 0; fewer than two scores give a std of 0.0.
        """
        results = []

        for feature_name, scores in importances.items():
            mean_val = sum(scores) / len(scores) if scores else 0
            # statistics.stdev needs at least two data points
            std_val = statistics.stdev(scores) if len(scores) > 1 else 0.0

            results.append({
                'featureName': feature_name,
                'importance': mean_val,
                'importanceSTD': std_val
            })

        return results

    def _permute_columns(self):
        """
        Shuffles, *in place* in self.permuted_df, every permutable input column and every
        rolling window column, independently within each time series identifier group
        (or globally when there are no identifier columns).

        For a rolling window column of length L, the first L - 1 rows of each group (in row
        order) are kept in place; presumably those hold incomplete windows.
        """
        if self.timeseries_identifier_columns:
            grouped = self.permuted_df.groupby(self.timeseries_identifier_columns, sort=False)
        else:
            grouped = self.permuted_df.groupby(lambda _: 0)  # single group

        for column in self._get_permutable_columns():
            self.permuted_df[column] = grouped[column].transform(self._shuffle_func)

        for rolling_window_column in self._get_rolling_windows_columns(self.permuted_df.columns):
            window_length = self._get_rolling_window_length(rolling_window_column)
            start_index = window_length - 1 if window_length > 0 else 0
            self.permuted_df[rolling_window_column] = grouped[rolling_window_column].transform(
                lambda values, start=start_index: self._shuffle_func(values, start)
            )

    def _shuffle_func(self, x, start_index=0):
        """
        Returns the values of `x` with the positions from `start_index` onwards shuffled with self.rng.

        :param pd.Series x: Values to shuffle.
        :param int start_index: Number of leading values kept in place.
        :return: np.ndarray with the same length as `x`.
        """
        idx = np.arange(len(x))
        if start_index < len(idx):
            # shuffling the slice (a view) permutes idx in place
            self.rng.shuffle(idx[start_index:])
        return x.values[idx]

    def _should_skip_permutation(self, timeseries_identifier, full_df, column):
        """
        Tells whether permuting `column` is pointless, in which case its importance is 0.

        Skips when the column holds at most one distinct non-null value over full_df (shuffling
        it changes nothing), or when the model only uses generated features and the column has
        no shift or rolling window.

        :param timeseries_identifier: Only used in the log message ("overall" when falsy).
        :param pd.DataFrame full_df: Data on which the column values are checked.
        :param str column: Column to check.
        :return: bool
        """
        # nunique() ignores NaN: an all-missing column (0 unique values) is as uninformative as a constant one
        is_uniform = full_df[column].nunique() <= 1
        if is_uniform:
            logger.info("All values for Identifier {} column {} are the same, skipping".format(timeseries_identifier or "overall", column))
            return True
        if self.use_only_generated_features and not self.shift_map.has_shift_or_window(column):
            # Classical ML models do not use the parent feature directly
            # Skip columns that have no shift or rolling window associated
            logger.info("Feature {} does not have associated generated features, skipping".format(column))
            return True
        return False

    @staticmethod
    def _get_rolling_windows_columns(df_columns_names, column=None):
        """
        Returns the rolling windows derived from the given column if it's defined.
        Returns all of them in the input list otherwise. Non-string column names are ignored.
        :param df_columns_names: list[str]: All column names
        :param column: str: Column name to filter on
        :return: list[str]: List of rolling windows
        """
        rolling_windows_columns = []
        for df_column_name in df_columns_names:
            if not isinstance(df_column_name, str) or not df_column_name.startswith("rolling_window:"):
                continue
            # Pattern: rolling_window:{length}:{operation}:{column} or rolling_window:{length}:{operation}:{column}:{category}
            # Use maxsplit=3 to preserve colons in column names
            parts = df_column_name.split(":", 3)
            if len(parts) != 4:
                continue
            if column is None:
                rolling_windows_columns.append(df_column_name)
                continue
            rest = parts[3]
            # rest is "{column}" or "{column}:{category}"
            if rest == column or rest.startswith(column + ":"):
                rolling_windows_columns.append(df_column_name)
        return rolling_windows_columns

    @staticmethod
    def _get_rolling_window_length(rolling_window_column):
        """Extract window length from rolling_window:{length}:{operation}:{column} pattern (0 if not such a name)."""
        parts = rolling_window_column.split(":")
        if len(parts) >= 2 and parts[0] == "rolling_window":
            return int(parts[1])
        return 0