# -*- coding: utf-8 -*-
import logging
import pandas as pd
from scipy import interpolate

from dataiku.doctor.timeseries.preparation.resampling.dataframe_helpers import filter_columns_without_enough_values
from dataiku.doctor.timeseries.preparation.resampling.dataframe_helpers import generic_check_compute_arguments
from dataiku.doctor.timeseries.preparation.resampling.utils import generate_date_range, supports_monthly_alignment
from dataiku.doctor.timeseries.utils import encode_timeseries_identifier, _groupby_compat
from dataiku.doctor.timeseries.utils import pretty_timeseries_identifiers


logger = logging.getLogger(__name__)

# To keep in sync with TimeseriesForecastingPreprocessingParams.java: TimeseriesSamplingParams.TimeseriesImputeMethod
# Accepted values of `interpolation_method` (filling gaps inside the range where a numerical column has data).
SUPPORTED_INTERPOLATION_METHODS = {
    "NEAREST",
    "PREVIOUS",
    "NEXT",
    "LINEAR",
    "QUADRATIC",
    "CUBIC",
    "CONSTANT",
    "STAIRCASE",
}
# Accepted values of `extrapolation_method` (filling dates outside the range where a numerical column has data).
SUPPORTED_EXTRAPOLATION_METHODS = {
    "PREVIOUS_NEXT",
    "NO_EXTRAPOLATION",
    "CONSTANT",
    "LINEAR",
    "QUADRATIC",
    "CUBIC",
}
# Accepted values of `category_imputation_method` (non-numerical columns).
SUPPORTED_CAT_IMPUTATION_METHODS = {
    "MOST_COMMON",
    "NULL",
    "CONSTANT",
    "PREVIOUS_NEXT",
    "PREVIOUS",
    "NEXT",
}
# Accepted values of `duplicate_timestamps_handling_method`.
SUPPORTED_DUPLICATE_TIMESTAMPS_HANDLING_METHODS = {
    "FAIL_IF_CONFLICTING",
    "DROP_IF_CONFLICTING",
    "MEAN_MODE",
}
# Accepted values of `start_date_mode` / `end_date_mode`.
SUPPORTED_BOUNDARY_DATES_METHODS = [
    "AUTO",
    "CUSTOM",
]


class Resampler:
    """Resample time series onto a regular time grid.

    Numerical columns are interpolated within the range where they have values and extrapolated outside of it;
    categorical columns are imputed according to ``category_imputation_method``. When identifier columns are
    given, each time series is resampled separately, on a reference time index shared by all of them.
    """

    def __init__(
        self,
        interpolation_method="LINEAR",
        extrapolation_method="PREVIOUS_NEXT",
        interpolation_constant_value=None,
        extrapolation_constant_value=None,
        category_imputation_method="PREVIOUS_NEXT",
        category_constant_value=None,
        time_step=1,
        time_unit="SECOND",
        time_unit_end_of_week="SUN",
        duplicate_timestamps_handling_method="FAIL_IF_CONFLICTING",
        unit_alignment=None,
        monthly_alignment=None,
        start_date_mode="AUTO",
        custom_start_date=None,
        end_date_mode="AUTO",
        custom_end_date=None,
    ):
        """Validate and store the resampling settings.

        Args:
            interpolation_method (str): one of SUPPORTED_INTERPOLATION_METHODS, used to fill gaps inside the range
                where a numerical column has data.
            extrapolation_method (str): one of SUPPORTED_EXTRAPOLATION_METHODS, used outside of that range.
            interpolation_constant_value: fill value used when interpolation_method is "CONSTANT".
            extrapolation_constant_value: fill value used when extrapolation_method is "CONSTANT".
            category_imputation_method (str): one of SUPPORTED_CAT_IMPUTATION_METHODS.
            category_constant_value: fill value used when category_imputation_method is "CONSTANT".
            time_step, time_unit, time_unit_end_of_week, unit_alignment, monthly_alignment: describe the output
                time grid; they are forwarded to ``generate_date_range``.
            duplicate_timestamps_handling_method (str): one of SUPPORTED_DUPLICATE_TIMESTAMPS_HANDLING_METHODS.
            start_date_mode, end_date_mode (str): "AUTO" or "CUSTOM"; falsy values are treated as "AUTO".
            custom_start_date, custom_end_date: required when the matching mode is "CUSTOM".

        Raises:
            ValueError: if a method or mode is unsupported, a CUSTOM date is missing, or an alignment is out of range.
        """
        if interpolation_method not in SUPPORTED_INTERPOLATION_METHODS:
            raise ValueError("Interpolation method '{}' is not supported. Should be one of: {}".format(interpolation_method, SUPPORTED_INTERPOLATION_METHODS))
        if extrapolation_method not in SUPPORTED_EXTRAPOLATION_METHODS:
            raise ValueError("Extrapolation method '{}' is not supported. Should be one of: {}".format(extrapolation_method, SUPPORTED_EXTRAPOLATION_METHODS))
        if category_imputation_method not in SUPPORTED_CAT_IMPUTATION_METHODS:
            raise ValueError("Non-numerical imputation method '{}' is not supported. Should be one of: {}".format(category_imputation_method,
                                                                                                                  SUPPORTED_CAT_IMPUTATION_METHODS))
        if duplicate_timestamps_handling_method not in SUPPORTED_DUPLICATE_TIMESTAMPS_HANDLING_METHODS:
            raise ValueError("Method to handle conflicting duplicate timestamps '{}' is not supported. Should be one of: {}".format(duplicate_timestamps_handling_method, SUPPORTED_DUPLICATE_TIMESTAMPS_HANDLING_METHODS))

        # check resampling dates
        if not start_date_mode:  # backward compatibility
            start_date_mode = "AUTO"
        if start_date_mode not in SUPPORTED_BOUNDARY_DATES_METHODS:
            raise ValueError("Resampling start date mode '{}' is not supported. Should be one of: {}".format(start_date_mode, SUPPORTED_BOUNDARY_DATES_METHODS))
        if start_date_mode == "CUSTOM" and custom_start_date is None:
            raise ValueError("Resampling start date mode is set to CUSTOM but the custom start date is unset")
        if not end_date_mode:  # backward compatibility
            end_date_mode = "AUTO"
        if end_date_mode not in SUPPORTED_BOUNDARY_DATES_METHODS:
            raise ValueError("Resampling end date mode '{}' is not supported. Should be one of: {}".format(end_date_mode, SUPPORTED_BOUNDARY_DATES_METHODS))
        if end_date_mode == "CUSTOM" and custom_end_date is None:
            raise ValueError("Resampling end date mode is set to CUSTOM but the custom end date is unset")

        # check period selected monthly alignment
        if supports_monthly_alignment(time_unit) and monthly_alignment is not None:
            if monthly_alignment == 0:  # backward compatibility
                monthly_alignment = 31
            elif monthly_alignment < 1 or monthly_alignment > 31:
                raise ValueError("Period selected day '{}' is invalid. It must be in [1, 31]".format(monthly_alignment))

        # check period selected unit alignment
        if time_unit == "QUARTER" and unit_alignment is not None and (unit_alignment < 1 or unit_alignment > 3):
            raise ValueError("Period selected month '{}' is invalid. It must be in [1, 3]".format(unit_alignment))
        if time_unit == "HALF_YEAR" and unit_alignment is not None and (unit_alignment < 1 or unit_alignment > 6):
            raise ValueError("Period selected month '{}' is invalid. It must be in [1, 6]".format(unit_alignment))
        if time_unit == "YEAR" and unit_alignment is not None and (unit_alignment < 1 or unit_alignment > 12):
            raise ValueError("Period selected month '{}' is invalid. It must be in [1, 12]".format(unit_alignment))

        self.interpolation_method = interpolation_method
        self.extrapolation_method = extrapolation_method
        self.interpolation_constant_value = interpolation_constant_value
        self.extrapolation_constant_value = extrapolation_constant_value
        self.category_imputation_method = category_imputation_method
        self.category_constant_value = category_constant_value
        self.start_date_mode = start_date_mode
        self.custom_start_date = custom_start_date
        self.end_date_mode = end_date_mode
        self.custom_end_date = custom_end_date

        self.time_step = time_step
        self.time_unit = time_unit
        self.time_unit_end_of_week = time_unit_end_of_week
        self.unit_alignment = unit_alignment
        self.monthly_alignment = monthly_alignment

        self.duplicate_timestamps_handling_method = duplicate_timestamps_handling_method

    def transform(
        self, df, datetime_column, timeseries_identifier_columns=None, numerical_columns=None, categorical_columns=None
    ):
        """Resample ``df`` and return a new dataframe with the same columns and a fresh RangeIndex.

        Rows with a null timestamp are dropped first. If fewer than 2 rows remain, they are returned without
        resampling (and keep their original index).

        Args:
            df (pd.DataFrame): input data.
            datetime_column (str): name of the timestamp column.
            timeseries_identifier_columns (list): columns identifying each time series; None/empty for a single one.
            numerical_columns (list): columns to interpolate/extrapolate.
            categorical_columns (list): columns to impute with the categorical imputation method.

        Returns:
            pd.DataFrame: the resampled data, with columns ordered like ``df``.
        """
        timeseries_identifier_columns = timeseries_identifier_columns or []
        numerical_columns = numerical_columns or []
        categorical_columns = categorical_columns or []

        generic_check_compute_arguments(datetime_column, timeseries_identifier_columns)

        # drop all rows where the timestamp is null
        df_no_nan = df.dropna(subset=[datetime_column])
        if len(df_no_nan.index) < 2:
            logger.warning("The timeseries has less than 2 rows with values, cannot resample.")
            return df_no_nan

        # when having multiple timeseries, their time range is not necessarily the same
        # we thus compute a unified time index for all partitions
        reference_time_index = self._compute_full_time_index(df_no_nan[datetime_column])

        if timeseries_identifier_columns:
            df_no_nan_grouped_by_timeseries_identifiers = df_no_nan.groupby(_groupby_compat(timeseries_identifier_columns))
            nb_timeseries = df_no_nan_grouped_by_timeseries_identifiers.ngroups
            resampled_dfs = []
            nb_identifiers = len(timeseries_identifier_columns)
            logger.info("Found {} time series to resample".format(nb_timeseries))
            for timeseries_identifier_values, df_of_timeseries_identifier in df_no_nan_grouped_by_timeseries_identifiers:
                timeseries_identifier = encode_timeseries_identifier(timeseries_identifier_values, timeseries_identifier_columns)
                logger.info("Resampling time series {} of shape: {}".format(
                    pretty_timeseries_identifiers(timeseries_identifier),
                    df_of_timeseries_identifier.shape
                ))
                df_of_timeseries_identifier_resampled = self._resample(
                    df_of_timeseries_identifier.drop(timeseries_identifier_columns, axis=1),
                    datetime_column,
                    numerical_columns,
                    categorical_columns,
                    reference_time_index,
                    timeseries_identifier=timeseries_identifier
                )
                # With a single identifier column the group key is used as-is; otherwise it holds one value per column.
                identifier_values = [timeseries_identifier_values] if nb_identifiers == 1 else list(timeseries_identifier_values)
                # Assign each identifier as a scalar so pandas broadcasts it over every resampled row
                # (a one-row DataFrame indexed like the resampled rows would not broadcast and raises a length mismatch).
                for identifier_column, identifier_value in zip(timeseries_identifier_columns, identifier_values):
                    df_of_timeseries_identifier_resampled[identifier_column] = identifier_value
                resampled_dfs.append(df_of_timeseries_identifier_resampled)
            df_resampled = pd.concat(resampled_dfs, sort=True)
        else:
            logger.info("Resampling a unique time series of shape: {}".format(df_no_nan.shape))
            df_resampled = self._resample(
                df_no_nan, datetime_column, numerical_columns, categorical_columns, reference_time_index
            )

        return df_resampled[df.columns].reset_index(drop=True)

    def _can_customize_resampling_dates(self):
        """Return True when custom start/end dates may extend the resampling range (i.e. extrapolation is enabled)."""
        # keep in sync with predsettings.js::canCustomizeResamplingDates() and time-series-card-config.component.ts::canCustomizeResamplingDates()
        return self.extrapolation_method != "NO_EXTRAPOLATION"

    @staticmethod
    def _localize_like(date, reference):
        """Return ``date`` as a pd.Timestamp expressed in the timezone of the ``reference`` timestamp.

        A naive ``date`` is localized to the reference timezone (and stays naive if the reference is naive).
        A tz-aware ``date`` is converted to the reference timezone, or made naive (keeping its wall time) when
        the reference is naive.
        """
        timestamp = pd.Timestamp(date)
        target_tz = reference.tzinfo
        if timestamp.tzinfo is not None and target_tz is not None:
            # tz_localize refuses already-aware timestamps: convert instead
            return timestamp.tz_convert(target_tz)
        return timestamp.tz_localize(target_tz)

    def _compute_full_time_index(self, datetime_column):
        """Create the full index of the resampling output dataframe, optionally extended to include custom start and end dates

        Custom dates can only widen the range covered by the data, never shrink it.

        Args:
            datetime_column (pd.Series): Series containing all timestamps

        Returns:
            The date range built by ``generate_date_range`` between the (possibly extended) start and end times.
        """
        start_time = datetime_column.min()
        end_time = datetime_column.max()
        if self._can_customize_resampling_dates():
            if self.start_date_mode == "CUSTOM" and self.custom_start_date:
                custom_start_date = self._localize_like(self.custom_start_date, start_time)
                if custom_start_date > start_time:
                    logger.info("Custom start date ({}) won't be used since data is known before that date (since {})".format(custom_start_date, start_time))
                else:
                    start_time = custom_start_date

            if self.end_date_mode == "CUSTOM" and self.custom_end_date:
                custom_end_date = self._localize_like(self.custom_end_date, end_time)
                if custom_end_date < end_time:
                    logger.info("Custom extrapolation end date ({}) won't be used since data is known after that date (up to {})".format(custom_end_date, end_time))
                else:
                    end_time = custom_end_date

        return generate_date_range(start_time, end_time, self.time_step, self.time_unit, self.time_unit_end_of_week, self.unit_alignment, self.monthly_alignment)

    def _resample(self, df, datetime_column, numerical_columns, categorical_columns, reference_time_index, timeseries_identifier=None):
        """Resample a single time series onto ``reference_time_index``.

        1. Remove duplicate timestamps (see ``_remove_duplicates``).
        2. Move the datetime column to the index and merge it with ``reference_time_index``.
        3. Create a numerical (int64) index of the df and save the corresponding reference index.
        4. Interpolate/extrapolate every numerical column with enough values, then impute categorical columns.

        Args:
            df (pd.DataFrame): rows of one time series (identifier columns already dropped).
            datetime_column (str): name of the timestamp column.
            numerical_columns (list): columns to interpolate/extrapolate.
            categorical_columns (list): columns to impute.
            reference_time_index: DateTimeIndex with all the dates for which we need data in the end - either the
                data in df, or interpolated.
            timeseries_identifier: encoded identifier, only used in log and error messages.

        Returns:
            pd.DataFrame: one row per kept date of ``reference_time_index`` (dates outside the data range are dropped
            with NO_EXTRAPOLATION), or the deduplicated ``df`` as is when it has fewer than 2 rows.
        """
        if not df[datetime_column].is_unique:
            df = self._remove_duplicates(df, datetime_column, numerical_columns, categorical_columns, timeseries_identifier)

        if len(df.index) < 2:
            logger.warning("The time series {} has less than 2 rows with values, cannot resample.".format(
                pretty_timeseries_identifiers(timeseries_identifier)
            ))
            return df

        # `scipy.interpolate.interp1d` does not like columns with less than 2 valid values, so we cannot resample them
        filtered_numerical_columns = filter_columns_without_enough_values(df, numerical_columns)

        df_resample = df.set_index(datetime_column).sort_index()
        # merge the reference time index with the original ones that has data
        # at this point, df_resample has as all the dates in need for data as index, and df data for matching dates.
        df_resample = df_resample.reindex(df_resample.index.union(reference_time_index))

        # `scipy.interpolate.interp1d` only works with numerical index, so we create one.
        # We take care that the values for this index are spaced in the same way the dates in the index are - so we just convert those to int representations (timestamps in nanoseconds).
        # Note that the interpolation logic is perfectly fine with negative values, and going from negative to positive, so there's no year 1970 bug !
        # Use int64 explicitly: `int` is 32-bit on some platforms (e.g. Windows with numpy < 2) and would overflow.
        df_resample["numerical_index"] = df_resample.index.values.astype("int64")
        reference_index = df_resample.loc[reference_time_index, "numerical_index"]
        category_imputation_index = pd.Index([])

        df_resample = df_resample.rename_axis(datetime_column).reset_index()
        df_resample.index = df_resample["numerical_index"]
        for filtered_column in filtered_numerical_columns:

            # Interpolation covers the dates between the first and last known value of this column,
            # extrapolation covers the dates outside of that range.
            df_without_nan = df.dropna(subset=[filtered_column])
            first_known_date = df_without_nan[datetime_column].min()
            last_known_date = df_without_nan[datetime_column].max()
            interpolation_index_mask = (df_resample[datetime_column] >= first_known_date) & (
                df_resample[datetime_column] <= last_known_date
            )
            interpolation_index = df_resample.index[interpolation_index_mask]

            extrapolation_index_mask = (df_resample[datetime_column] < first_known_date) | (
                df_resample[datetime_column] > last_known_date
            )
            extrapolation_index = df_resample.index[extrapolation_index_mask]

            # rows that had a value before any imputation; used to fit the scipy interpolation/extrapolation functions
            index_with_data = df_resample.loc[interpolation_index, filtered_column].dropna().index

            if self.interpolation_method == "CONSTANT":
                df_resample.loc[interpolation_index, filtered_column] = df_resample.loc[
                    interpolation_index, filtered_column
                ].fillna(self.interpolation_constant_value)
            elif self.interpolation_method == "STAIRCASE":
                # Unlike scipy, pandas linear interpolate ignores the index and averages.
                # Restrict it to the interpolation range: on the whole column pandas would also fill the trailing
                # NaNs with the last known value, overriding the configured extrapolation (e.g. CONSTANT).
                df_resample.loc[interpolation_index, filtered_column] = df_resample.loc[
                    interpolation_index, filtered_column
                ].interpolate(method="linear")
            else:
                interpolation_function = interpolate.interp1d(
                    index_with_data,
                    df_resample.loc[index_with_data, filtered_column],
                    kind=self.interpolation_method.lower(),
                    axis=0,
                    fill_value="extrapolate",
                )

                df_resample.loc[interpolation_index, filtered_column] = interpolation_function(
                    df_resample.loc[interpolation_index].index
                )

            if self.extrapolation_method in ["LINEAR", "QUADRATIC", "CUBIC"]:
                extrapolation_function = interpolate.interp1d(
                    index_with_data,
                    df_resample.loc[index_with_data, filtered_column],
                    kind=self.extrapolation_method.lower(),
                    axis=0,
                    fill_value="extrapolate",
                )

                df_resample.loc[extrapolation_index, filtered_column] = extrapolation_function(
                    df_resample.loc[extrapolation_index].index
                )

            elif self.extrapolation_method == "CONSTANT":
                df_resample.loc[extrapolation_index, filtered_column] = df_resample.loc[
                    extrapolation_index, filtered_column
                ].fillna(self.extrapolation_constant_value)

            elif self.extrapolation_method == "PREVIOUS_NEXT":
                # only this column is needed: ffill/bfill operate column by column
                filled_column = df_resample[filtered_column].ffill().bfill()
                df_resample.loc[extrapolation_index, filtered_column] = filled_column.loc[extrapolation_index]

            elif self.extrapolation_method == "NO_EXTRAPOLATION":
                reference_index = reference_index[~reference_index.isin(extrapolation_index.values)]
            category_imputation_index = category_imputation_index.union(extrapolation_index).union(interpolation_index)

        if len(filtered_numerical_columns) == 0 and self.extrapolation_method == "NO_EXTRAPOLATION":
            # No numerical column was resampled, so the loop above dropped no date: drop the dates outside the range
            # of the data here, whatever the categorical imputation method (including "NULL" or no categorical column).
            extrapolation_index_mask = (df_resample[datetime_column] < df[datetime_column].min()) | (
                df_resample[datetime_column] > df[datetime_column].max()
            )
            extrapolation_index = df_resample.index[extrapolation_index_mask]
            reference_index = reference_index[~reference_index.isin(extrapolation_index.values)]

        if len(categorical_columns) > 0 and self.category_imputation_method != "NULL":
            if len(filtered_numerical_columns) == 0:
                # when no numerical columns have been resampled, we don't have a category_imputation_index
                # we need to impute categorical columns on the whole resampled frame
                df_resample = self._fill_in_category_values(df_resample, categorical_columns)
            elif len(category_imputation_index) > 0:
                df_processed = df_resample.loc[category_imputation_index]
                df_resample.loc[category_imputation_index] = self._fill_in_category_values(df_processed, categorical_columns)
        df_resampled = df_resample.loc[reference_index].drop("numerical_index", axis=1)
        return df_resampled

    def _fill_in_category_values(self, df, categorical_columns):
        """Return a copy of ``df`` whose categorical columns are imputed with ``self.category_imputation_method``.

        "NULL" (or any other value not handled below) leaves the values untouched.
        """
        category_filled_df = df.copy()
        if self.category_imputation_method == "CONSTANT":
            category_filled_df.loc[:, categorical_columns] = category_filled_df.loc[:, categorical_columns].fillna(
                self.category_constant_value
            )
        elif self.category_imputation_method == "PREVIOUS":
            category_filled_df.loc[:, categorical_columns] = category_filled_df.loc[:, categorical_columns].ffill()
        elif self.category_imputation_method == "NEXT":
            category_filled_df.loc[:, categorical_columns] = category_filled_df.loc[:, categorical_columns].bfill()
        elif self.category_imputation_method == "PREVIOUS_NEXT":
            category_filled_df.loc[:, categorical_columns] = category_filled_df.loc[:, categorical_columns].ffill().bfill()
        elif self.category_imputation_method == "MOST_COMMON":
            modes = category_filled_df.loc[:, categorical_columns].mode()
            # `mode()` returns an empty frame when every categorical column is entirely null: nothing to fill then
            if not modes.empty:
                category_filled_df.loc[:, categorical_columns] = category_filled_df.loc[:, categorical_columns].fillna(
                    modes.iloc[0]
                )
        return category_filled_df

    def _remove_duplicates(self, df, datetime_column, numerical_columns, categorical_columns, timeseries_identifier):
        """Handle duplicate timestamps according to ``self.duplicate_timestamps_handling_method``.

        Rows that are identical on the datetime, numerical and categorical columns are always collapsed to the
        first one. Remaining (conflicting) duplicates are then:
            - FAIL_IF_CONFLICTING: rejected with an exception,
            - DROP_IF_CONFLICTING: all dropped,
            - MEAN_MODE: aggregated per timestamp (mean for numerical, mode for categorical, first value otherwise).

        Raises:
            Exception: on conflicting duplicates with FAIL_IF_CONFLICTING, or if no row is left.
        """
        resampling_columns = [datetime_column] + numerical_columns + categorical_columns
        other_columns = [column for column in df.columns if column not in resampling_columns]

        total_rows = len(df.index)

        # drop non-conflicting duplicates except for the first occurrences
        df = df.drop_duplicates(subset=resampling_columns, keep="first")

        if self.duplicate_timestamps_handling_method == "FAIL_IF_CONFLICTING":
            if not df[datetime_column].is_unique:
                raise Exception("""The time series {} contains conflicting duplicate timestamps.
                    Try to change/add identifier columns or change the duplicate timestamp handling method.
                    """.format(pretty_timeseries_identifiers(timeseries_identifier)))

        elif self.duplicate_timestamps_handling_method == "DROP_IF_CONFLICTING":
            df = df.drop_duplicates(subset=[datetime_column], keep=False)

        elif self.duplicate_timestamps_handling_method == "MEAN_MODE":
            def most_common_or_empty(values):
                modes = values.mode()
                return modes.iloc[0] if not modes.empty else ""

            mean_mode_aggregation = {}

            for column in numerical_columns:
                mean_mode_aggregation[column] = "mean"

            for column in categorical_columns:
                mean_mode_aggregation[column] = most_common_or_empty

            for column in other_columns:  # for columns that are not resampled, we don't care about the aggregation method
                # return a scalar (the first value), as expected from an aggregation function
                mean_mode_aggregation[column] = lambda values: values.iloc[0]

            df = df.groupby(by=[datetime_column], as_index=False).agg(mean_mode_aggregation)

        lost_rows = total_rows - len(df.index)
        if lost_rows > 0:
            logger.warning("The time series {} lost {} rows after handling duplicate timestamps.".format(
                pretty_timeseries_identifiers(timeseries_identifier), lost_rows
            ))

        if df.empty:
            raise Exception("""The time series {} is empty after handling duplicate timestamps.
                Try to change/add identifier columns or change the duplicate timestamp handling method.
                """.format(pretty_timeseries_identifiers(timeseries_identifier)))

        return df