import logging

import numpy as np
from statsmodels.tsa.stattools import acf

from dataiku.doctor.preprocessing_handler import write_resource
from dataiku.doctor.timeseries.utils import timeseries_iterator

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class TimeseriesAutoShiftsGenerator:
    """
    Select, for every shifted feature (target included), the most correlated shifts.

    Correlations with the target are computed per time series, aggregated across time series
    (weighted by time series length), and the selection is written to a JSON resource.
    """
    RESOURCE_NAME = "timeseries_auto_shifts_generation"
    RESOURCE_TYPE = "json"

    def __init__(self, data_folder_context,
                 timeseries_identifiers_columns,
                 target_column,
                 past_only_columns,
                 known_in_advance_columns,
                 past_only_range,
                 known_in_advance_range,
                 horizon_length,
                 max_selected_shifts,
                 ):
        """
        Parameters
        ----------
        data_folder_context :
            Folder context the resource is written to (see `_write_auto_shifts_resource`).
        timeseries_identifiers_columns : list of str
            Columns used to split the input into individual time series.
        target_column : str
            Column the correlations are computed against. When it is listed among the shifted
            columns, its auto-correlation is used and it always uses `past_only_range`.
        past_only_columns : list of str
            Features whose shifts are searched within `past_only_range`.
        known_in_advance_columns : list of str
            Features whose shifts are searched within `known_in_advance_range`.
        past_only_range, known_in_advance_range : dict
            Dicts with integer "min" and "max" keys, both boundaries included.
        horizon_length : int
            Forecast horizon; `past_only_range["max"]` must not exceed -horizon_length.
        max_selected_shifts : int
            Maximum number of shifts selected per feature.
        """
        self.data_folder_context = data_folder_context
        self.timeseries_identifiers_columns = timeseries_identifiers_columns
        self.shifts_columns = past_only_columns + known_in_advance_columns
        self.target_column = target_column
        self.past_only_columns = past_only_columns
        self.known_in_advance_columns = known_in_advance_columns
        self.past_only_range = past_only_range
        self.known_in_advance_range = known_in_advance_range
        self.horizon_length = horizon_length
        self.max_selected_shifts = max_selected_shifts

        self._reset_computation_state()

    def _reset_computation_state(self):
        """Start from an empty result and empty per-column correlation accumulators."""
        self.json_data = {
            "aggregated": {}
        }
        # One list per shifted column, filled by _compute_correlation (one entry per time series)
        self.cf_weights_array_per_columns = {column: [] for column in self.shifts_columns}
        self.cf_values_array_per_columns = {column: [] for column in self.shifts_columns}
        self.cf_shifts_array_per_columns = {column: [] for column in self.shifts_columns}

    def process(self, input_df):
        """
        Compute correlations for every shifted column on every time series, select the most
        correlated shifts per column, then store the results in a resource for later pipeline steps.

        Each call starts from a clean state, so calling it again on another input does not mix
        correlations computed on the previous input.
        """
        logger.info("Start auto-shift computation for {} features".format(len(self.shifts_columns)))
        self._validate_params()
        # Drop anything accumulated by a previous call before computing on this input
        self._reset_computation_state()
        self._compute_correlation(input_df)
        self._aggregate_correlations_and_select_shifts()
        self._write_auto_shifts_resource()
        logger.info("End auto-shift computation")

    def _validate_params(self):
        """
        Raise ValueError if the shift parameters are inconsistent.

        Shift ranges must span at least two shifts. Past-only shifts must end at or before
        -horizon_length. Known-in-advance shifts must end at or before 0 (leads are not supported).
        """
        if self.max_selected_shifts < 1:
            raise ValueError("Auto-shifts params - Max number of selected shifts must be greater than 0.")

        past_only_range_length = self.past_only_range["max"] - self.past_only_range["min"]
        if past_only_range_length < 1:
            raise ValueError("Auto-shifts params - Past only features range length must be positive.")

        known_in_advance_range_length = self.known_in_advance_range["max"] - self.known_in_advance_range["min"]
        if known_in_advance_range_length < 1:
            raise ValueError("Auto-shifts params - Known in advance features range length must be positive.")

        past_only_range_max_valid_value = -self.horizon_length
        if self.past_only_range["max"] > past_only_range_max_valid_value:
            raise ValueError("Auto-shifts params - Past only features range end from forecasted point must be smaller or equal to -{} (prediction length).".format(self.horizon_length))

        known_in_advance_range_max_valid_value = 0  # no support for leads
        if self.known_in_advance_range["max"] > known_in_advance_range_max_valid_value:
            raise ValueError("Auto-shifts params - Known in advance features range end from forecasted point must be negative or zero.")

    def _compute_correlation(self, input_df):
        """
        For every shifted column and every time series, compute the correlation with the target
        over the column's shift range. Accumulate the values, together with the time series
        length used as the aggregation weight.
        """
        for shift_column in self.shifts_columns:
            # The target (auto-correlation) always uses the past-only range
            if shift_column == self.target_column or shift_column in self.past_only_columns:
                shift_range = self.past_only_range
            else:
                shift_range = self.known_in_advance_range
            min_shift = shift_range["min"]
            max_shift = shift_range["max"]

            # range() is already empty when min_shift > max_shift
            self.cf_shifts_array_per_columns[shift_column] = list(range(min_shift, max_shift + 1))

            for _, df_of_timeseries_identifier in timeseries_iterator(
                    input_df, self.timeseries_identifiers_columns
            ):
                target_values = df_of_timeseries_identifier[self.target_column]
                if shift_column == self.target_column:
                    cf_values = TimeseriesAutoShiftsOptimizer.compute_acf(
                        target_values,
                        min_shift,
                        max_shift)
                else:
                    cf_values = TimeseriesAutoShiftsOptimizer.compute_ccf(
                        target_values,
                        df_of_timeseries_identifier[shift_column],
                        min_shift,
                        max_shift,
                    )

                self.cf_values_array_per_columns[shift_column].append(cf_values)
                self.cf_weights_array_per_columns[shift_column].append(len(df_of_timeseries_identifier))

    def _aggregate_correlations_and_select_shifts(self):
        """Aggregate per-series correlations of each column and store the selected shifts in json_data."""
        for shift_column in self.shifts_columns:
            column_cf_values_aggregated = TimeseriesAutoShiftsOptimizer.aggregate_cf(
                self.cf_values_array_per_columns[shift_column],
                self.cf_weights_array_per_columns[shift_column])

            column_selected_shifts = TimeseriesAutoShiftsOptimizer.select_most_correlated_shifts(
                column_cf_values_aggregated,
                self.cf_shifts_array_per_columns[shift_column],
                max_selected_shifts=self.max_selected_shifts,
            )

            self.json_data["aggregated"][shift_column] = {
                "selected_shifts": column_selected_shifts,
                "correlation_values": column_cf_values_aggregated,
                "correlation_shifts": self.cf_shifts_array_per_columns[shift_column],
            }

    def _write_auto_shifts_resource(self):
        """Persist json_data as the auto-shifts resource in the data folder."""
        write_resource(self.data_folder_context, self.RESOURCE_NAME, self.RESOURCE_TYPE, self.json_data)

    @staticmethod
    def load_auto_shifts_resource(data_folder_context):
        """Read back the "<RESOURCE_NAME>.<RESOURCE_TYPE>" JSON file from the given folder context."""
        return data_folder_context.read_json(
            TimeseriesAutoShiftsGenerator.RESOURCE_NAME + "." + TimeseriesAutoShiftsGenerator.RESOURCE_TYPE)


class TimeseriesAutoShiftsOptimizer:
    """
    Automatically selects significant shifts for a time series using auto-correlation and cross-correlation analysis.

    Correlation arrays produced here are ordered by shift, from `min_shift` to `max_shift`.
    """

    @staticmethod
    def compute_acf(series_values, min_shift, max_shift):
        """
        Compute the auto-correlation function (ACF) values for a time series.

        Parameters
        ----------
        series_values : pandas.Series
            Time series values (must already be resampled to a uniform timestep).
        min_shift : int
            Left/Start boundary for the shifts range to consider when computing auto-correlations, it is included in the result.
            Must be smaller than or equal to `max_shift`.
        max_shift : int
            Right/End boundary for the shifts range to consider when computing auto-correlations, it is included in the result.
            Must be smaller than or equal to 0 (leads are not supported).

        Returns
        -------
        acf_values : numpy.ndarray
            A 1D array of (max_shift - min_shift + 1) auto-correlation values, ordered from min_shift to max_shift.
            Shifts that cannot be computed (series too short, fewer than two values, or an all-zero series)
            are filled with NaN.
        """
        shift_range_length = max_shift - min_shift + 1
        acf_values = np.array([])

        if len(series_values) > 1 and any(series_values):
            # acf returns lags 0..nlags; shift s corresponds to lag -s
            acf_values = acf(series_values, nlags=-min_shift, fft=True)
            if max_shift < 0:
                # Drop lags 0..(-max_shift - 1), keep lags -max_shift..-min_shift
                acf_values = acf_values[-max_shift:]

        # A short series yields fewer lags than requested: the missing (largest) lags are padded with NaN
        acf_values = TimeseriesAutoShiftsOptimizer.sanitize_correlation_values(acf_values, shift_range_length)

        # Values are ordered by increasing lag (decreasing shift): reverse to follow shift order
        return acf_values[::-1]

    @staticmethod
    def compute_ccf(target_values, feature_values, min_shift, max_shift):
        """
        Compute the cross-correlation function (CCF) values for a time series, between target and external feature.

        The value at shift `s` is the Pearson correlation between target[t] and feature[t + s],
        computed over the pairs where both are available.

        Parameters
        ----------
        target_values : pandas.Series
            Time series target values (must already be resampled to a uniform timestep).
        feature_values : pandas.Series
            Time series feature values (must already be resampled to a uniform timestep).
        min_shift : int
            Left/Start boundary for the shifts range to consider when computing cross-correlations, it is included in the result.
        max_shift : int
            Right/End boundary for the shifts range to consider when computing cross-correlations, it is included in the result.

        Returns
        -------
        ccf_values : numpy.ndarray
            A 1D array of (max_shift - min_shift + 1) cross-correlation values, ordered from min_shift to max_shift.
            If either series has fewer than two values, every entry is NaN.
            The array is empty when min_shift > max_shift.
        """
        shift_range_length = max_shift - min_shift + 1
        ccf_values = np.array([])

        # A correlation needs at least two observations; shorter series are left to the NaN padding below
        if len(target_values) > 1 and len(feature_values) > 1 and min_shift <= max_shift:
            ccf_values = np.array(
                [
                    # shift(-s) aligns feature[t + s] with target[t]; corr skips the NaNs introduced by shifting
                    target_values.corr(feature_values.shift(-shift))
                    for shift in range(min_shift, max_shift + 1)
                ],
                dtype=float,
            )

        return TimeseriesAutoShiftsOptimizer.sanitize_correlation_values(ccf_values, shift_range_length)

    @staticmethod
    def sanitize_correlation_values(correlation_values, shifts_range_length):
        """
        Right-pad `correlation_values` with NaN so it holds at least `shifts_range_length` elements.

        Longer inputs are returned unchanged.
        """
        missing_count = shifts_range_length - len(correlation_values)
        if missing_count > 0:
            correlation_values = np.pad(correlation_values, (0, missing_count),
                                        mode='constant',
                                        constant_values=np.nan)
        return correlation_values

    @staticmethod
    def aggregate_cf(cf_values_array, cf_weight_array):
        """
        Compute the weighted average of multiple correlation function value arrays, ignoring NaNs.

        Parameters
        ----------
        cf_values_array : list of array-like
            List of correlation function value arrays, each corresponding to a different time series group.
            All arrays must have the same length.
        cf_weight_array : array-like
            1D array of weights, one per correlation function array, used in the aggregation.

        Returns
        -------
        aggregated_cf : numpy.ndarray
            A 1D array of the weighted average of the ABSOLUTE correlation values at each shift.
            If all values are NaN at a shift index, the result is 0 for that shift.
        """
        # Use absolute values during aggregation to select values strongly correlated regardless of sign
        values = np.absolute(np.array(cf_values_array))
        weights = np.array(cf_weight_array).reshape(-1, 1)

        # NaN values are replaced by 0 and given a weight of 0
        mask = ~np.isnan(values)
        values_masked = np.nan_to_num(values)
        weights_masked = np.where(mask, weights, 0.0)

        weighted_sum = np.sum(values_masked * weights_masked, axis=0)
        sum_weights = np.sum(weights_masked, axis=0)

        # Shifts with no valid value (zero total weight) get 0 instead of a division by zero
        return np.divide(weighted_sum, sum_weights, out=np.zeros_like(weighted_sum), where=sum_weights != 0)

    @staticmethod
    def select_most_correlated_shifts(cf_values, cf_shifts, selection_threshold=0.25, max_selected_shifts=10):
        """
        Select the most correlated shifts based on correlation values.

        Parameters
        ----------
        cf_values : numpy.ndarray
            1D array of correlation function values (output of compute_acf, compute_ccf or aggregate_cf).
            Values are expected to be between -1.0 and 1.0.
        cf_shifts : list of int
            Shifts matching cf_values, element by element.
        selection_threshold: float
            Absolute correlation value a shift must strictly exceed to be selected. Default to 0.25.
        max_selected_shifts: int
            Maximum number of shifts matching criteria that can be selected. Default to 10.

        Returns
        -------
        selected_shifts : list of int
            Shifts whose absolute correlation value exceeds 'selection_threshold', sorted by descending
            absolute correlation value. Ties keep their order from cf_shifts.
            At most 'max_selected_shifts' are returned.
        """
        # NaN scores never pass the threshold comparison
        candidates = [
            (shift, abs(value))
            for shift, value in zip(cf_shifts, cf_values)
            if abs(value) > selection_threshold
        ]

        # sorted() is stable, so equally-scored shifts keep their original order
        ranked = sorted(candidates, key=lambda candidate: candidate[1], reverse=True)

        return [shift for shift, _ in ranked[:max_selected_shifts]]
