import logging

import numpy as np
import pandas as pd
from pandas.api.types import is_numeric_dtype

from dataiku.core import doctor_constants
from dataiku.core import intercom
from dataiku.doctor import step_constants
from dataiku.doctor.diagnostics import diagnostics
from dataiku.doctor.diagnostics.diagnostics import DiagnosticType
from dataiku.doctor.exception import TimeseriesResamplingException
from dataiku.doctor.prediction.common import get_initial_intrinsic_perf_data
from dataiku.doctor.timeseries.utils.pandas_compat import str_to_datetime_compat
from dataiku.doctor.utils.model_io import to_pkl
from dataiku.doctor.timeseries.perf.model_perf import PER_TIMESERIES_METRICS
from dataiku.doctor.timeseries.perf.model_perf import TIMESERIES_AGGREGATED_METRICS
from dataiku.doctor.timeseries.perf.model_perf import TimeseriesModelIntrinsicScorer
from dataiku.doctor.timeseries.perf.model_perf import TimeseriesModelScorer
from dataiku.doctor.timeseries.preparation.preprocessing import get_external_features
from dataiku.doctor.timeseries.preparation.preprocessing import get_filtered_features
from dataiku.doctor.timeseries.preparation.preprocessing import resample_timeseries
from dataiku.doctor.timeseries.preparation.resampling.utils import get_frequency
from dataiku.doctor.timeseries.utils import FOLD_ID_COLUMN, log_df, get_dataframe_of_timeseries_identifier
from dataiku.doctor.timeseries.utils import FORECAST_COLUMN
from dataiku.doctor.timeseries.utils import JSONForecast
from dataiku.doctor.timeseries.utils import SINGLE_TIMESERIES_IDENTIFIER
from dataiku.doctor.timeseries.utils import add_timeseries_identifiers_columns
from dataiku.doctor.timeseries.utils import build_quantile_column_name
from dataiku.doctor.timeseries.utils import pretty_timeseries_identifiers
from dataiku.doctor.timeseries.utils import timeseries_iterator

logger = logging.getLogger(__name__)


class TimeseriesTrainingHandler(object):
    """Fit, score and persist a timeseries forecasting model.

    Progress steps are reported through ``listener``. Forecasts and scores are
    accumulated in ``model_scorer`` by :meth:`train`. Model artifacts (pickled
    estimator, params, perf files, forecasts, predicted data) are written through
    ``model_folder_context``.
    """

    def __init__(self, core_params, model_scorer, use_external_features, algorithm, model_folder_context, listener):
        """
        :param core_params: core settings dict; read for the prediction length, time/target variables,
            timeseries identifier columns, quantiles and (through ``get_frequency``) the frequency
        :param model_scorer: scorer holding the ``forecasts`` and ``scores`` produced during training
        :param use_external_features: whether external features are used; when True, no future
            forecast is produced at training time
        :param algorithm: algorithm wrapper (actual params, "fit before predict" policy)
        :param model_folder_context: folder context where model artifacts are written
        :param listener: progress listener providing ``push_step`` context managers
        """
        self.core_params = core_params
        self.model_folder_context = model_folder_context

        self.listener = listener

        self.prediction_length = core_params[doctor_constants.PREDICTION_LENGTH]
        self.time_variable = core_params[doctor_constants.TIME_VARIABLE]
        self.target_variable = core_params[doctor_constants.TARGET_VARIABLE]
        self.timeseries_identifier_columns = core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS]

        self.quantiles = sorted(core_params[doctor_constants.QUANTILES])

        self.frequency = get_frequency(core_params)

        self.algorithm = algorithm

        self.use_external_features = use_external_features

        self.model_scorer = model_scorer

        # Intrinsic perf data, filled by update_intrinsic_perf_data and written to iperf.json
        self.model_intrinsic_perf = {}

    def train(self, estimator, modeling_params, train_df, test_df=None, historical_df=None, preprocessed_external_features=None,
              shift_map=None, score_model=False, fold_id=None,
              save_model=False, step_name=step_constants.ProcessingStep.STEP_FITTING):
        """Fit ``estimator`` on ``train_df``, then optionally save it and/or score it on ``test_df``.

        :param estimator: model exposing ``fit`` and ``predict``
        :param modeling_params: modeling params, used to compute the actual params when saving
        :param train_df: training data
        :param test_df: evaluation data; scoring is skipped when it is None
        :param historical_df: data preceding ``test_df``, handed to the scorer
        :param preprocessed_external_features: passed to ``estimator.fit`` as ``external_features``
        :param shift_map: passed to ``estimator.fit``
        :param score_model: if True (and ``test_df`` is not None), forecast and score the test set
        :param fold_id: passed to ``log_df`` and to the scorer
        :param save_model: if True, save the estimator and ``actual_params.json`` and, without external
            features, append future forecasts to the scorer
        :param step_name: listener step under which the fit runs
        """
        with self.listener.push_step(step_name):
            log_df(logger, train_df, self.time_variable, fold_id, "Training model on")
            estimator.fit(train_df, external_features=preprocessed_external_features, shift_map=shift_map)

        if save_model:
            actual_params = self.algorithm.get_actual_params(modeling_params, estimator, fit_params=None)
            with self.listener.push_step(step_constants.ProcessingStep.STEP_SAVING):
                to_pkl(estimator, self.model_folder_context)
                self.model_folder_context.write_json("actual_params.json", actual_params)

            # Future values are only forecast when no external features are used
            if not self.use_external_features:
                forecasts_by_timeseries = estimator.predict(train_df, None, self.quantiles)
                self.model_scorer.append_future_forecasts(forecasts_by_timeseries, train_df)

        if score_model and test_df is not None:
            with self.listener.push_step(step_constants.ProcessingStep.STEP_SCORING):
                fit_before_predict = self.algorithm.should_fit_before_predict()
                # With fewer than 2 horizons in the test set, the fit on train_df is reused instead of refitting.
                # When train_df and test_df are contiguous, historical_df equals train_df, so no refit is needed.
                # When there is a gap between them (custom interval), historical_df is larger than train_df, and
                # skipping the refit means the model treats train and test as if they were contiguous.
                if fit_before_predict and self._max_number_of_horizons_in_test(historical_df, test_df) < 2:
                    fit_before_predict = False

                log_df(logger, test_df, self.time_variable, fold_id, "Scoring model on")
                forecasts_by_timeseries = self.model_scorer.predict_all_test_timesteps(estimator, historical_df, test_df, fit_before_predict)
                self.model_scorer.score(historical_df, test_df, forecasts_by_timeseries, fold_id, append_forecasts=True, append_scores=True)

    def _max_number_of_horizons_in_test(self, historical_df, test_df):
        """Return the largest number of horizons fitting in ``test_df`` over the time series of ``historical_df``.

        For each time series, the count is ``(test length - prediction_length) // prediction_offset + 1``.
        The result is never below 0.
        """
        max_nb_horizons = 0
        for timeseries_identifier, _ in timeseries_iterator(historical_df, self.timeseries_identifier_columns):
            test_length = len(get_dataframe_of_timeseries_identifier(test_df, timeseries_identifier).index)
            nb_horizons = (test_length - self.prediction_length) // self.model_scorer.prediction_offset + 1
            max_nb_horizons = max(nb_horizons, max_nb_horizons)
        return max_nb_horizons

    def save_intrinsic_scores_and_forecasts(self, full_df, modeling_params, preprocessing_params, estimator):
        """Write ``iperf.json``, ``forecasts.json.gz`` and, with several time series, ``one_forecast.json``.

        ``forecasts.json.gz`` only contains the first N time series, where N is fetched from the backend;
        a diagnostic is added when forecasts are truncated. ``one_forecast.json`` contains only the first
        time series (used by the model snippets).
        """
        initial_intrinsic_perf_data = get_initial_intrinsic_perf_data(full_df.to_numpy(), False)

        total_nb_timeseries = len(self.model_scorer.forecasts)
        initial_intrinsic_perf_data["totalNbOfTimeseries"] = total_nb_timeseries

        model_intrinsic_scorer = TimeseriesModelIntrinsicScorer(
            modeling_params,
            preprocessing_params,
            estimator,
            self.algorithm,
            self.prediction_length,
            self.frequency,
            initial_intrinsic_perf_data
        )

        self.update_intrinsic_perf_data(model_intrinsic_scorer.score())

        self.model_folder_context.write_json("iperf.json", self.model_intrinsic_perf)

        max_nb_timeseries = intercom.jek_or_backend_get_call("ml/prediction/get-max-nb-timeseries-in-forecast-charts")

        forecasts = self.model_scorer.forecasts
        # Save historical and forecast values for only the first max_nb_timeseries time series
        is_truncated = total_nb_timeseries > max_nb_timeseries
        if is_truncated:
            displayed_forecasts = {
                timeseries_identifier: forecasts[timeseries_identifier]
                for timeseries_identifier in list(forecasts)[:max_nb_timeseries]
            }
        else:
            displayed_forecasts = forecasts
        self.model_folder_context.write_json("forecasts.json.gz",
                                             {"perTimeseries": self.model_scorer.remove_naninf_in_forecasts(displayed_forecasts)})
        if is_truncated:
            diagnostic_message = "Only the first {} out of the total {} trained time series will be displayed in the model report forecast charts.".format(max_nb_timeseries, total_nb_timeseries)
            diagnostics.add_or_update(DiagnosticType.ML_DIAGNOSTICS_DATASET_SANITY_CHECKS, diagnostic_message)

        if total_nb_timeseries > 1:
            # Save only the first time series forecasts in another file, used in the model snippets.
            # NOTE(review): "nanif" differs from the other remove_naninf* scorer methods; confirm the name is intended.
            first_timeseries_identifier, first_timeseries_forecasts = next(iter(forecasts.items()))
            self.model_folder_context.write_json("one_forecast.json",
                                                 {"perTimeseries": {first_timeseries_identifier: self.model_scorer.remove_nanif_in_one_forecast(first_timeseries_forecasts)}})

    def _get_perf_file_name(self):
        """Return the perf file name.

        The name is ``perf.json.gz`` (compressed) when timeseries identifier columns are set, i.e. when
        there may be multiple time series. Otherwise it is ``perf.json``.
        """
        if self.timeseries_identifier_columns:
            return "perf.json.gz"
        return "perf.json"

    def save_scores(self):
        """Write the perf file(s) from the scores accumulated by the model scorer.

        With a single score (no k-fold), it is written directly in the model folder. With several scores
        (k-fold), each fold score is written in ``fold_{fold_id}/`` and the model folder receives the
        per-fold aggregation (``TimeseriesModelScorer.aggregate_metrics_per_fold``) of both the aggregated
        metrics and the per-timeseries metrics.
        """
        scores = self.model_scorer.scores
        perf_file_name = self._get_perf_file_name()

        if len(scores) == 1:
            self.model_folder_context.write_json(perf_file_name, self.model_scorer.remove_naninf(scores[0]))
            return

        for fold_id, score in enumerate(scores):
            fold_model_folder_context = self.model_folder_context.get_subfolder_context("fold_{}".format(fold_id))
            fold_model_folder_context.create_if_not_exist()
            fold_model_folder_context.write_json(perf_file_name, self.model_scorer.remove_naninf(score))

        folds_aggregation = {
            # Aggregate the aggregated metrics of each fold
            TIMESERIES_AGGREGATED_METRICS: TimeseriesModelScorer.aggregate_metrics_per_fold(
                pd.DataFrame([score[TIMESERIES_AGGREGATED_METRICS] for score in scores])
            ),
            PER_TIMESERIES_METRICS: {},
        }

        # Map each timeseries identifier to the list of its per-fold metrics dicts. A time series only gets
        # an entry for the folds in which it was evaluated.
        per_timeseries_fold_metrics = {}
        for score in scores:
            for identifier, metrics in score[PER_TIMESERIES_METRICS].items():
                per_timeseries_fold_metrics.setdefault(identifier, []).append(metrics)

        for identifier, fold_metrics in per_timeseries_fold_metrics.items():
            folds_aggregation[PER_TIMESERIES_METRICS][identifier] = TimeseriesModelScorer.aggregate_metrics_per_fold(
                pd.DataFrame(fold_metrics)
            )
        self.model_folder_context.write_json(perf_file_name, self.model_scorer.remove_naninf(folds_aggregation))

    def save_predicted_data(self, resampled_df, preprocessing_params, schema):
        """Write ``predicted.csv`` (tab-separated) from the forecasts accumulated by the model scorer.

        For each time series, the file holds the historical values, the evaluation forecasts (with their
        fold id when there are several folds) and, without external features, the future forecasts.

        :param resampled_df: resampled input data. It is only read, to retrieve the external feature values
            when external features are used. It is not modified.
        :param preprocessing_params: used to list the external features columns
        :param schema: dataset schema; its numeric integer columns are rendered without decimals
        """
        multiple_folds = len(self.model_scorer.scores) > 1
        predicted_dfs = []
        for timeseries_identifier, forecast_dict in self.model_scorer.forecasts.items():
            # 1. retrieve historical values
            historical_df_of_timeseries_identifier = self._build_timeseries_df_from_json_forecast(
                forecast_dict, JSONForecast.GROUND_TRUTH_TIME, JSONForecast.GROUND_TRUTH_VALUES,
                timeseries_identifier=timeseries_identifier
            )

            # 2. retrieve evaluation forecast values
            forecast_df_of_timeseries_identifier = self._build_timeseries_df_from_json_forecast(
                forecast_dict, JSONForecast.FORECAST_TIME, JSONForecast.FORECAST_VALUES, is_forecast=True
            )
            if multiple_folds:
                forecast_df_of_timeseries_identifier[FOLD_ID_COLUMN] = forecast_dict[JSONForecast.FORECAST_FOLD_ID]

            predicted_df_of_timeseries_identifier = historical_df_of_timeseries_identifier.merge(
                forecast_df_of_timeseries_identifier, on=self.time_variable, how="left"
            )

            # 3. retrieve future forecast values if no external features were used
            if not self.use_external_features:
                future_forecast_df_of_timeseries_identifier = self._build_timeseries_df_from_json_forecast(
                    forecast_dict, JSONForecast.FUTURE_TIME, JSONForecast.FUTURE_FORECAST_VALUES,
                    timeseries_identifier=timeseries_identifier, is_forecast=True
                )

                predicted_df_of_timeseries_identifier = pd.concat(
                    [predicted_df_of_timeseries_identifier, future_forecast_df_of_timeseries_identifier], axis=0
                )

            predicted_dfs.append(predicted_df_of_timeseries_identifier)

        predicted_df = pd.concat(predicted_dfs, ignore_index=True)

        if self.use_external_features:
            # Merge with the resampled data to retrieve the external features. The resampled times are localized
            # to UTC (presumably to match the times parsed from the forecasts - confirm) on a copy, so the
            # caller's dataframe is left untouched and repeated calls do not re-localize an aware column.
            localized_resampled_df = resampled_df.copy()
            localized_resampled_df[self.time_variable] = localized_resampled_df[self.time_variable].dt.tz_localize('UTC')
            predicted_df = predicted_df.merge(
                localized_resampled_df,
                on=[self.time_variable, self.target_variable] + (self.timeseries_identifier_columns or []),
                how="left",
            )

        # converting date column to iso format to get a 'Date' meaning
        predicted_df[self.time_variable] = predicted_df[self.time_variable].dt.strftime('%Y-%m-%dT%H:%M:%S.%f').str.slice(0, 23) + 'Z'

        # Some integer columns (target, numerical external features, fold id) can contain missing values in the
        # predicted data, and pandas cannot cast columns with NaNs to int, so they are rendered as strings instead.
        integer_columns = [
            column["name"]
            for column in schema["columns"]
            if column["type"] in ["tinyint", "smallint", "int", "bigint"]
            and column["name"] in predicted_df
            and is_numeric_dtype(predicted_df[column["name"]])  # categorical integer features don't contain any decimals and would fail on the missing-value check, this fixes bug of [sc-148437]
        ]

        if multiple_folds:
            integer_columns.append(FOLD_ID_COLUMN)

        for integer_column in integer_columns:
            # pd.isna (rather than np.isnan) also handles pd.NA from nullable integer dtypes
            predicted_df[integer_column] = predicted_df[integer_column].apply(
                lambda value: "{:.0f}".format(value) if not pd.isna(value) else ""
            )

        ordered_columns = (self.timeseries_identifier_columns or []) + [self.time_variable]
        if self.use_external_features:
            ordered_columns.extend(get_external_features(preprocessing_params))
        if multiple_folds:
            ordered_columns.append(FOLD_ID_COLUMN)
        ordered_columns.extend([self.target_variable, FORECAST_COLUMN])
        ordered_columns.extend(build_quantile_column_name(q) for q in self.quantiles)

        predicted_df = predicted_df[ordered_columns]
        with self.model_folder_context.get_file_path_to_write("predicted.csv") as predicted_file:
            predicted_df.to_csv(predicted_file, sep="\t", header=True, index=False)

    def _build_timeseries_df_from_json_forecast(self, forecast_dict, time_key, target_key, timeseries_identifier=None, is_forecast=False):
        """Create a dataframe from the forecast dict of a single time series.

        Identifier columns and quantile columns are added optionally.

        Args:
            forecast_dict (dict): forecast dict of a single time series of the forecast.json file
            time_key (str): key mapping to the time values in forecast_dict
            target_key (str): key mapping to the target values in forecast_dict (also read from each quantile entry)
            timeseries_identifier (optional): if truthy, timeseries identifier columns are added to the dataframe
            is_forecast (bool, optional): if True, values go to the forecast column and quantile columns are added,
                else values go to the target column

        Returns:
            pandas.DataFrame: one row per time value
        """
        target_column = FORECAST_COLUMN if is_forecast else self.target_variable
        df = pd.DataFrame(
            {
                self.time_variable: str_to_datetime_compat(forecast_dict[time_key]),
                target_column: forecast_dict[target_key],
            }
        )
        if is_forecast:
            for quantile_forecast in forecast_dict[JSONForecast.QUANTILES]:
                df[build_quantile_column_name(quantile_forecast[JSONForecast.QUANTILE])] = quantile_forecast[target_key]

        if timeseries_identifier:
            add_timeseries_identifiers_columns(df, timeseries_identifier)

        return df

    def update_intrinsic_perf_data(self, data):
        """Merge ``data`` (a dict) into the intrinsic perf data written to ``iperf.json``."""
        self.model_intrinsic_perf.update(data)


def resample_for_training(df, schema, resampling_params, core_params, preprocessing_params, supports_shifts, compute_zero_target_ratio_diagnostic):
    """Resample the target and the external features together.

    Before resampling, every time series of ``df`` is checked to contain at least 2 valid
    (non-NaN) target values; otherwise a ``TimeseriesResamplingException`` is raised.

    :param supports_shifts: when True, features with the "INPUT_PAST_ONLY" role are resampled as well
    :return: the result of ``resample_timeseries``
    """
    roles = ["TARGET", "INPUT"]
    if supports_shifts:
        roles = roles + ["INPUT_PAST_ONLY"]
    numerical_features = get_filtered_features(preprocessing_params, include_types=["NUMERIC"], include_roles=roles)
    categorical_features = get_filtered_features(preprocessing_params, exclude_types=["NUMERIC"], include_roles=roles)

    identifier_columns = core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS]
    target = core_params[doctor_constants.TARGET_VARIABLE]

    # A time series with strictly fewer than 2 valid target values cannot be resampled for training
    for identifier, series_df in timeseries_iterator(df, identifier_columns):
        if series_df[target].count() >= 2:
            continue
        if identifier == SINGLE_TIMESERIES_IDENTIFIER:
            subject = "Input time series"
        else:
            subject = "Time series {}".format(pretty_timeseries_identifiers(identifier))
        raise TimeseriesResamplingException(
            subject + " cannot be resampled because its target column contains less than 2 valid values."
        )

    return resample_timeseries(
        df, schema, resampling_params, core_params, numerical_features, categorical_features,
        compute_zero_target_ratio_diagnostic=compute_zero_target_ratio_diagnostic,
    )
