from collections import OrderedDict
import logging

import numpy as np
import pandas as pd

from dataiku.core import dkujson
from dataiku.doctor.diagnostics import diagnostics
from dataiku.doctor.timeseries.utils.pandas_compat import _groupby_compat
import hashlib


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# NOTE(review): not referenced in this part of the file; presumably a key designating the dataframe
# that holds all the time series at once -- confirm against callers.
FULL_TIMESERIES_DF_IDENTIFIER = "__full_timeseries_df"
# Sentinel identifier used in place of an encoded identifier when there are no time series identifier
# columns (yielded by timeseries_iterator, special-cased by get_dataframe_of_timeseries_identifier and
# add_timeseries_identifiers_columns).
SINGLE_TIMESERIES_IDENTIFIER = "__single_timeseries_identifier"
# NOTE(review): not referenced in this part of the file; presumably the names of the dataframe columns
# holding the forecast values and the fold ids -- confirm against callers.
FORECAST_COLUMN = "forecast"
FOLD_ID_COLUMN = "fold_id"


class JSONForecast:
    """Keys of the per-time-series forecast objects stored in the forecasts.json file.

    Example of one entry, keyed by the encoded time series identifier:
    {
        ENCODED_TIMESERIES_IDENTIFIER: {
            "groundTruthTime": ["2021-01-01", "2021-01-02", "2021-01-03", "2021-01-04", "2021-01-05", "2021-01-06", "2021-01-07"],
            "groundTruth": [1, 2, 3, 4, 5, 6, 7],
            "forecastTime": ["2021-01-04", "2021-01-05", "2021-01-06", "2021-01-07"],
            "forecast": [4, 5, 6, 7],
            "foldId": [0, 0, 1, 1],
            "quantiles": [
                {"quantile": 0.1, "forecast": [0.2, 0.3, 0.2, 0.3], "futureForecast": [0.2, 0.3]},
                ...,
                {"quantile": 0.9, "forecast": [2.2, 3.3, 2.2, 3.3], "futureForecast": [2.2, 3.3]}
            ],
            "futureTime": ["2021-01-08", "2021-01-09"],
            "futureForecast": [8, 9]
        }
    }
    'futureTime' and 'futureForecast' only exist if no external features were used.
    """
    # Ground truth series
    GROUND_TRUTH_VALUES = "groundTruth"
    GROUND_TRUTH_TIME = "groundTruthTime"

    # Forecasts, with the fold each forecast value comes from
    FORECAST_VALUES = "forecast"
    FORECAST_TIME = "forecastTime"
    FORECAST_FOLD_ID = "foldId"

    # Quantile forecasts: list under QUANTILES, each item carrying its level under QUANTILE
    QUANTILES = "quantiles"
    QUANTILE = "quantile"

    # Forecasts past the end of the ground truth
    FUTURE_FORECAST_VALUES = "futureForecast"
    FUTURE_TIME = "futureTime"

    # NOTE(review): not part of the example above; presumably the past values (and their timestamps)
    # given as context to the future forecast -- confirm against the code writing forecasts.json.
    FUTURE_FORECAST_CONTEXT_VALUES = "futureForecastContext"
    FUTURE_FORECAST_CONTEXT_GROUND_TRUTH_TIME = "futureForecastContextTime"


class ModelForecast:
    """Keys of the forecast dictionaries returned by the models."""
    # Point forecasts and their quantile counterparts
    FORECAST_VALUES = "forecast_values"
    QUANTILES_FORECASTS = "quantiles_forecasts"
    # Timestamps key of the forecast dictionary
    TIMESTAMPS = "timestamps"


def encode_timeseries_identifier(timeseries_identifier_values, timeseries_identifier_columns):
    """Encode identifier (column name, value) pairs into a string, with the pairs in sorted order

    Sorting the pairs makes the encoding independent of the order in which the columns are given.

    Args:
        timeseries_identifier_values (tuple/else): tuple of values of each identifier column or a single value (obtained from the pandas groupby)
        timeseries_identifier_columns (list): list of timeseries identifier column names

    Returns:
        The result of dkujson.dumps on an OrderedDict of the sorted (column name, value) pairs
    """
    if isinstance(timeseries_identifier_values, tuple):
        identifier_values = timeseries_identifier_values
    else:
        # A single value is wrapped so that it can be paired with its column name like a tuple would be
        identifier_values = (timeseries_identifier_values,)

    name_value_pairs = list(zip(timeseries_identifier_columns, identifier_values))
    name_value_pairs.sort()
    return dkujson.dumps(OrderedDict(name_value_pairs))


def get_dataframe_of_timeseries_identifier(full_df, timeseries_identifier):
    """Return the rows of full_df that belong to the time series designated by timeseries_identifier

    Args:
        full_df (DataFrame/None): dataframe to filter
        timeseries_identifier (str): encoded identifier, or SINGLE_TIMESERIES_IDENTIFIER

    Returns:
        None when full_df is None, full_df itself for SINGLE_TIMESERIES_IDENTIFIER, otherwise the rows
        matching every (column, value) pair of the decoded identifier, with a fresh 0-based index
    """
    if full_df is None:
        return None
    if timeseries_identifier == SINGLE_TIMESERIES_IDENTIFIER:
        return full_df

    # One boolean condition per identifier column; a row is kept only if it satisfies all of them
    column_conditions = [
        full_df[column_name] == column_value
        for column_name, column_value in dkujson.loads(timeseries_identifier).items()
    ]
    rows_to_keep = None
    for condition in column_conditions:
        rows_to_keep = condition if rows_to_keep is None else rows_to_keep & condition

    filtered_df = full_df[rows_to_keep]
    return filtered_df.reset_index(drop=True)


def timeseries_iterator(df, timeseries_identifier_columns):
    """Yield (timeseries identifier, dataframe) pairs, one per time series found in df

    Without identifier columns, df is yielded once under SINGLE_TIMESERIES_IDENTIFIER. Otherwise df is
    grouped by the identifier columns and each group is yielded under its encoded identifier.
    """
    if not timeseries_identifier_columns:
        yield SINGLE_TIMESERIES_IDENTIFIER, df
        return

    groups = df.groupby(_groupby_compat(timeseries_identifier_columns))
    for group_values, group_df in groups:
        yield encode_timeseries_identifier(group_values, timeseries_identifier_columns), group_df


def str_to_hash(s):
    """
    Computes a short hash of a string.
    :param s: String to hash (encoded as UTF-8 before hashing).
    :return: The first 8 hexadecimal characters of the MD5 digest of the string.
    """
    hex_digest = hashlib.md5(s.encode('utf-8')).hexdigest()
    # Only a prefix of the digest is kept
    return hex_digest[:8]


def build_quantile_column_name(quantile):
    """Return the column name of a quantile: "quantile_" followed by the quantile formatted with 4
    significant digits and stripped of its decimal point (e.g. 0.1 gives "quantile_01")."""
    formatted_quantile = format(quantile, ".4g")
    digits_only = formatted_quantile.replace(".", "")
    return "quantile_" + digits_only


def add_timeseries_identifiers_columns(df, timeseries_identifier):
    """Set, in place, one column of df per identifier column, filled with the value of that identifier

    Nothing is done when timeseries_identifier is SINGLE_TIMESERIES_IDENTIFIER.
    """
    if timeseries_identifier == SINGLE_TIMESERIES_IDENTIFIER:
        return

    identifier_values_by_column = dkujson.loads(timeseries_identifier)
    for column_name, column_value in identifier_values_by_column.items():
        # Assigning a scalar sets the same value on every row of df
        df[column_name] = column_value


def pretty_timeseries_identifiers(timeseries_identifier):
    """Return a human-readable rendering of an encoded time series identifier, for log and error messages

    Args:
        timeseries_identifier (str): encoded identifier, as produced by encode_timeseries_identifier

    Returns:
        str: "(column1: value1 | column2: value2)", or "" when the identifier is empty or is
            SINGLE_TIMESERIES_IDENTIFIER
    """
    # The single-series sentinel is not an encoded identifier and has no column to display. Other helpers
    # of this module (e.g. add_timeseries_identifiers_columns) special-case it the same way before decoding.
    if not timeseries_identifier or timeseries_identifier == SINGLE_TIMESERIES_IDENTIFIER:
        return ""

    identifier_items = dkujson.loads(timeseries_identifier).items()
    return "({})".format(
        " | ".join("{}: {}".format(column_name, column_value) for column_name, column_value in identifier_items)
    )


def future_date_range(last_past_date, prediction_length, frequency, monthly_day_alignment=None):
    """Return the formatted dates of the prediction_length steps that follow last_past_date

    A date range of prediction_length + 1 periods is generated from last_past_date at the given
    frequency, and its first date is dropped so that prediction_length dates remain. When
    monthly_day_alignment is truthy, the day of each date is aligned with set_day_within_month
    before the first date is dropped.
    """
    generated_dates = pd.date_range(start=str(last_past_date), periods=prediction_length + 1, freq=frequency)
    if monthly_day_alignment:
        generated_dates = set_day_within_month(generated_dates, monthly_day_alignment)

    future_dates = generated_dates[1:]
    return format_forecast_dates(future_dates)


def format_forecast_dates(dates):
    """Return the dates as an array of strings with millisecond precision and no timezone suffix"""
    iso_dates = np.datetime_as_string(dates, unit='ms', timezone='naive')
    return iso_dates.astype(str)

def set_day_within_month(dates, monthly_day_alignment):
    """ Set each date's day to monthly_day_alignment while keeping them within the original month they were in

    No-op if monthly_day_alignment is 0 or None

    Args:
        dates (pd.DatetimeIndex): dates to align. Left unmodified.
        monthly_day_alignment (int/None): target day of the month, between 1 and 31 (0 or None: no alignment).
            Months shorter than monthly_day_alignment use their last day instead (e.g. 30 in February).

    Returns:
        pd.DatetimeIndex: dates themselves for a no-op, otherwise a new index (same name) where only the day
            of each date has changed; the time of day is kept.

    Raises:
        ValueError: if dates is not a DatetimeIndex or monthly_day_alignment is not between 0 and 31.
    """
    if not isinstance(dates, pd.DatetimeIndex):
        raise ValueError("dates must be a DatetimeIndex")
    if monthly_day_alignment is None or monthly_day_alignment == 0:
        return dates
    if monthly_day_alignment < 1 or monthly_day_alignment > 31:
        raise ValueError("monthly_day_alignment must be between 0 and 31")

    target_day = int(monthly_day_alignment)
    adjusted_dates = []
    for date in dates:
        if not pd.isna(date):
            # The day is replaced directly rather than by subtracting a MonthBegin offset: that subtraction
            # moves a date that is already the 1st of its month into the previous month.
            # Clamping handles months with fewer days than monthly_day_alignment e.g. 30 in February
            date = date.replace(day=min(target_day, date.days_in_month))
        adjusted_dates.append(date)

    return pd.DatetimeIndex(adjusted_dates, name=dates.name)


def ignored_timeseries_warning_message(ignored_timeseries_identifiers, explanation_message):
    """Build the log message listing every ignored time series, along with the reason they were ignored

    Args:
        ignored_timeseries_identifiers (list): encoded identifiers of the ignored time series
        explanation_message (str): reason, inserted between parentheses in the message
    """
    auxiliary = "has"
    if len(ignored_timeseries_identifiers) > 1:
        auxiliary = "have"

    pretty_identifiers = [
        pretty_timeseries_identifiers(identifier) for identifier in ignored_timeseries_identifiers
    ]
    message_template = "The following time series {} been ignored ({}): {}"
    return message_template.format(auxiliary, explanation_message, ", ".join(pretty_identifiers))


def ignored_timeseries_diagnostic_message(ignored_timeseries_identifiers, explanation_message):
    """Build the diagnostic message giving the number of ignored time series and the reason they were ignored

    Args:
        ignored_timeseries_identifiers (list): identifiers of the ignored time series (only counted)
        explanation_message (str): reason, inserted between parentheses in the message
    """
    ignored_count = len(ignored_timeseries_identifiers)
    if ignored_count > 1:
        auxiliary = "have"
    else:
        auxiliary = "has"

    message_template = "{} time series {} been ignored ({}). Check the logs for more details."
    return message_template.format(ignored_count, auxiliary, explanation_message)


def add_ignored_timeseries_diagnostics_and_logs(
        timeseries_identifier_columns, unseen_timeseries_identifiers, too_short_timeseries_identifiers,
        all_timeseries_ignored, min_required_length, recipe_type, diagnostic_type
    ):
    """Report the time series ignored because unseen during training or too short for scoring/evaluation.

    With identifier columns, each non-empty category of ignored time series gets a warning in the logs
    (listing the time series) and a diagnostic (giving their count). Without identifier columns, nothing
    is logged. In both cases, an error is raised when all time series were ignored.

    Args:
        timeseries_identifier_columns (list)
        unseen_timeseries_identifiers (list): Time series unseen during training of statistical models.
        too_short_timeseries_identifiers (list): Time series too short for scoring/evaluation.
        all_timeseries_ignored (bool): Whether all time series were ignored.
        min_required_length (int): Minimum required length for scoring/evaluation.
        recipe_type (str): Whether it's a scoring or evaluation recipe.
        diagnostic_type (DiagnosticType): DiagnosticType to use for the ignored series.

    Raises:
        ValueError: When all time series are ignored.
    """
    has_identifier_columns = bool(timeseries_identifier_columns)

    if has_identifier_columns and unseen_timeseries_identifiers:
        several_unseen = len(unseen_timeseries_identifiers) > 1
        unseen_explanation = "because {} not seen during training".format(
            "they were" if several_unseen else "it was"
        )
        logger.warning(ignored_timeseries_warning_message(unseen_timeseries_identifiers, unseen_explanation))
        diagnostics.add_or_update(
            diagnostic_type,
            ignored_timeseries_diagnostic_message(unseen_timeseries_identifiers, unseen_explanation)
        )

    if has_identifier_columns and too_short_timeseries_identifiers:
        several_too_short = len(too_short_timeseries_identifiers) > 1
        too_short_explanation = "because {} length {} smaller than the min required length for {} ({})".format(
            "their" if several_too_short else "its",
            "are" if several_too_short else "is",
            recipe_type,
            min_required_length,
        )
        logger.warning(ignored_timeseries_warning_message(too_short_timeseries_identifiers, too_short_explanation))
        diagnostics.add_or_update(
            diagnostic_type,
            ignored_timeseries_diagnostic_message(too_short_timeseries_identifiers, too_short_explanation)
        )

    if not all_timeseries_ignored:
        return

    # Nothing is left to process: fail, with a message worded for one or for several input time series
    if has_identifier_columns:
        error_prefix, unseen_verb, too_short_verb = "All input time series ", "were", "are"
    else:
        error_prefix, unseen_verb, too_short_verb = "Input time series ", "was", "is"

    error_reasons = []
    if unseen_timeseries_identifiers:
        error_reasons.append("{} not seen during training".format(unseen_verb))
    if too_short_timeseries_identifiers:
        error_reasons.append("{} shorter than the min required length of {} for {}".format(
            too_short_verb,
            min_required_length,
            recipe_type,
        ))
    raise ValueError(error_prefix + " or ".join(error_reasons) + ". Check the logs for more details.")


def prefix_custom_metric_name(custom_metric_name):
    """
    Return the name of a custom metric prefixed with "custom_", so that it cannot collide with the name
    of an existing metric.
    """
    prefixed_name_template = "custom_{}"
    return prefixed_name_template.format(custom_metric_name)


def log_df(dataframe_logger, df, time_variable, fold_id=None, prefix=""):
    """Log, at info level, the shape of df and the first and last values of its time_variable column

    Nothing is logged when df is None, is empty or has no time_variable column.

    Args:
        dataframe_logger: logger receiving the message
        df (DataFrame/None): dataframe to describe
        time_variable (str): name of the time column
        fold_id (int/None): 0-based fold index, logged 1-based when given
        prefix (str): text placed at the start of the message
    """
    if fold_id is not None:
        fold_log = "(fold %s)" % (fold_id + 1)
    else:
        fold_log = ""

    if df is None or time_variable not in df or len(df) <= 0:
        return

    row_count, column_count = df.shape[0], df.shape[1]
    time_values = df[time_variable]
    first_time, last_time = time_values.iloc[0], time_values.iloc[-1]
    dataframe_logger.info(
        "%s dataframe of shape (%s, %s) from %s to %s %s"
        % (prefix, row_count, column_count, first_time, last_time, fold_log)
    )