import json

import numpy as np
import pandas as pd

from dataiku.core import doctor_constants
from dataiku.core.doctor_constants import TARGET_VARIABLE, TIME_VARIABLE
from dataiku.doctor import step_constants
from dataiku.doctor.timeseries.models import TimeseriesForecastingAlgorithm
from dataiku.doctor.timeseries.perf.model_perf import TimeseriesModelScorer
from dataiku.doctor.timeseries.preparation.preprocessing import add_rolling_windows_for_training, add_rolling_windows_for_scoring, TimeseriesPreprocessing, \
    get_windows_list
from dataiku.doctor.timeseries.utils import SINGLE_TIMESERIES_IDENTIFIER, get_dataframe_of_timeseries_identifier, \
    ModelForecast
from dataiku.doctor.timeseries_interactive.interactive_scenarios_handler import TimeseriesInteractiveScenariosHandler
from dataiku.doctor.utils.listener import ProgressListener
from dataiku.doctor.utils.model_io import from_pkl

# strftime pattern applied to the ground-truth timestamps written by
# InteractiveScoringComputer ("groundTruthTime"): date and time separated by
# "T", microseconds after the dot, and a literal trailing "Z". The "Z" is
# emitted verbatim; this pattern performs no timezone conversion by itself.
DATE_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ'


class InteractiveScoringComputer:
    """Computes forecasts for interactive time series what-if scenarios.

    On construction the trained model is unpickled, the preprocessing pipeline
    is instantiated and rolling windows are added to the full dataset.
    :meth:`compute_forecasts` then scores every stored scenario of one time
    series and persists the results through the scenarios handler.
    """

    def __init__(self, core_params, model_folder_context, preprocessing_folder_context, preprocessing_params,
                 modeling_params, resolved_params, full_df, external_features):
        """Initializes the computer.

        :param core_params: Core parameters of the model.
        :type core_params: dict
        :param model_folder_context: Folder context for the model.
        :type model_folder_context: FolderContext
        :param preprocessing_folder_context: Folder context for preprocessing data.
        :type preprocessing_folder_context: FolderContext
        :param preprocessing_params: Parameters for preprocessing.
        :type preprocessing_params: dict
        :param modeling_params: Parameters for modeling. Must contain the "metrics" and "algorithm" keys;
            "isShiftWindowsCompatible" is optional and defaults to False.
        :type modeling_params: dict
        :param resolved_params: Resolved parameters.
        :type resolved_params: dict
        :param full_df: The full dataset.
        :type full_df: pd.DataFrame
        :param external_features: List of external feature names.
        :type external_features: list
        """
        self.core_params = core_params
        self.external_features = external_features
        self.metrics_params = modeling_params["metrics"]
        self.model_folder_context = model_folder_context
        self.preprocessing_params = preprocessing_params

        listener = ProgressListener()
        self.algorithm = TimeseriesForecastingAlgorithm.build(modeling_params["algorithm"])
        supports_external_features = self.algorithm.EXTERNAL_FEATURES_COMPATIBILITY.supports_external_features()
        self.timeseries_preprocessing = TimeseriesPreprocessing(preprocessing_folder_context, core_params,
                                                                preprocessing_params, modeling_params, listener,
                                                                supports_external_features)
        self.clf = from_pkl(model_folder_context)
        self.fit_before_predict = self.algorithm.should_fit_before_predict()

        # Only the side effect matters here: make sure the sub-folder exists.
        # The returned context is not used anywhere else in this class.
        interactive_scoring_model_folder_context = model_folder_context.get_subfolder_context(
            "interactive-scoring-analysis")
        interactive_scoring_model_folder_context.create_if_not_exist()

        # Rolling windows are only used when the modeling params flag the model as compatible.
        if modeling_params.get("isShiftWindowsCompatible", False):
            self.windows_list = get_windows_list(preprocessing_params)
        else:
            self.windows_list = []
        self.full_df = add_rolling_windows_for_training(full_df, core_params, self.windows_list,
                                                        self.preprocessing_params, self.model_folder_context)

        self.interactive_scenarios_handler = TimeseriesInteractiveScenariosHandler(model_folder_context,
                                                                                   resolved_params,
                                                                                   preprocessing_params, core_params,
                                                                                   self.full_df)

    def compute_forecasts(self, identifier):
        """Computes forecasts for all scenarios of a given time series.

        The historical data of the time series is preprocessed once and shared by every scenario.

        :param identifier: The identifier of the time series to forecast.
        :type identifier: str
        :return: The forecasts for all scenarios of the time series.
        :rtype: dict
        """
        identifier_full_df = get_dataframe_of_timeseries_identifier(self.full_df, identifier)
        preprocess_on_resampled_df = self.algorithm.ONE_MODEL_FOR_MULTIPLE_TS
        self.timeseries_preprocessing.load_resources()
        self.timeseries_preprocessing.create_timeseries_preprocessing_handlers(identifier_full_df,
                                                                               preprocess_on_resampled_df,
                                                                               use_saved_resources=True)

        transformed_full_df = self.timeseries_preprocessing.process(
            identifier_full_df,
            step_constants.ProcessingStep.STEP_PREPROCESS_TRAIN,
            preprocess_on_resampled_df
        )

        scenarios_metadata = self.interactive_scenarios_handler.get_scenarios_metadata(identifier)
        # Only the scenario ids (the keys of "names") are needed to score each scenario.
        for scenario_id in scenarios_metadata["names"]:
            self._process_single_scenario(identifier=identifier,
                                          scenario_id=scenario_id,
                                          identifier_full_df=identifier_full_df,
                                          transformed_full_df=transformed_full_df,
                                          preprocess_on_resampled_df=preprocess_on_resampled_df)

        return self.interactive_scenarios_handler.get_scenarios_forecasts(identifier)

    def _load_scenario_df(self, identifier, scenario_id):
        """Loads a scenario dataframe and completes it with the columns needed downstream.

        The dataframe returned by the scenarios handler is modified in place.

        :param identifier: The identifier of the time series.
        :type identifier: str
        :param scenario_id: The identifier of the scenario.
        :type scenario_id: str
        :return: The scenario dataframe with target, time, identifier and external feature columns set.
        :rtype: pd.DataFrame
        """
        scenario_df = self.interactive_scenarios_handler.get_scenario_df(identifier, scenario_id)
        target_column = self.core_params[TARGET_VARIABLE]
        time_column = self.core_params[TIME_VARIABLE]

        # The target column is filled with a constant 0.
        # NOTE(review): presumably a placeholder so downstream steps find the column - confirm.
        scenario_df[target_column] = 0
        # Drop any timezone information so the timestamps are timezone-naive.
        scenario_df[time_column] = pd.to_datetime(scenario_df[time_column]).dt.tz_localize(None)

        if identifier != SINGLE_TIMESERIES_IDENTIFIER:
            # The identifier is a JSON object; each key becomes a constant column.
            for column_name, column_value in json.loads(identifier).items():
                scenario_df[column_name] = column_value

        for external_feature in self.external_features:
            if external_feature not in scenario_df.columns:
                # Missing external features are added as all-null columns using the dtype found in full_df.
                scenario_df[external_feature] = pd.Series(None, index=scenario_df.index,
                                                          dtype=self.full_df[external_feature].dtype)
        return scenario_df

    @staticmethod
    def _collect_forecasts(model_scorer, timeseries_forecasts):
        """Gathers point and quantile forecasts and passes them through the scorer's NaN/inf clean-up.

        :param model_scorer: The scorer that produced the forecasts.
        :type model_scorer: TimeseriesModelScorer
        :param timeseries_forecasts: The forecasts of one time series, keyed by ModelForecast constants.
        :type timeseries_forecasts: dict
        :return: The value returned by ``model_scorer.remove_nanif_in_one_forecast``.
        :rtype: dict
        """
        forecasts = {
            "forecast": timeseries_forecasts[ModelForecast.FORECAST_VALUES],
            "forecastTime": timeseries_forecasts[ModelForecast.TIMESTAMPS],
            # The quantile forecasts are looked up per quantile, so nothing is read when there are no quantiles.
            "quantiles": [
                {
                    "forecast": timeseries_forecasts[ModelForecast.QUANTILES_FORECASTS][idx],
                    "quantile": quantile
                }
                for idx, quantile in enumerate(model_scorer.quantiles)
            ]
        }
        return model_scorer.remove_nanif_in_one_forecast(forecasts)

    def _process_single_scenario(self, identifier, scenario_id, identifier_full_df, transformed_full_df,
                                 preprocess_on_resampled_df):
        """Computes and saves the forecast of a single scenario.

        :param identifier: The identifier of the time series.
        :type identifier: str
        :param scenario_id: The identifier of the scenario.
        :type scenario_id: str
        :param identifier_full_df: The full dataframe for the given identifier.
        :type identifier_full_df: pd.DataFrame
        :param transformed_full_df: The preprocessed full dataframe for the given identifier.
        :type transformed_full_df: pd.DataFrame
        :param preprocess_on_resampled_df: Whether to preprocess on the resampled dataframe.
        :type preprocess_on_resampled_df: bool
        """
        scenario_df = self._load_scenario_df(identifier, scenario_id)
        time_column = self.core_params[TIME_VARIABLE]
        target_column = self.core_params[TARGET_VARIABLE]

        # Keep only the history strictly before the first scenario timestamp.
        # NOTE(review): the mask is built from identifier_full_df but applied to transformed_full_df,
        # so it relies on preprocessing keeping the same index - confirm.
        scenario_start = np.datetime64(scenario_df[time_column].iloc[0])
        is_before_scenario = pd.to_datetime(identifier_full_df[time_column]) < scenario_start
        historical_df_subset = transformed_full_df[is_before_scenario]

        scenario_df = add_rolling_windows_for_scoring(scenario_df, self.core_params, self.windows_list,
                                                      self.preprocessing_params, self.model_folder_context)
        transformed_scenario_df = self.timeseries_preprocessing.process(
            scenario_df,
            step_constants.ProcessingStep.STEP_PREPROCESS_TEST,
            preprocess_on_resampled_df
        )

        model_scorer = TimeseriesModelScorer.build(self.core_params, self.metrics_params, use_external_features=True)
        all_forecasts = model_scorer.predict_all_test_timesteps(self.clf, historical_df_subset,
                                                                transformed_scenario_df, self.fit_before_predict)
        serializable_forecasts = self._collect_forecasts(model_scorer, all_forecasts[identifier])

        # Return as many trailing historical points as there are scenario rows, and at least two.
        ground_truth_window = slice(-max(2, len(scenario_df)), None)
        scenarios_forecasts = {
            "forecastTime": transformed_scenario_df[time_column].astype(str).values,
            "groundTruthTime": historical_df_subset[time_column][ground_truth_window].dt.strftime(DATE_FORMAT).values,
            "groundTruth": historical_df_subset[target_column][ground_truth_window].values,
        }
        scenarios_forecasts["forecast"] = serializable_forecasts["forecast"]
        scenarios_forecasts["quantiles"] = serializable_forecasts["quantiles"]
        self.interactive_scenarios_handler.write_scenarios(scenarios_forecasts, identifier, scenario_id)
