import logging
import pandas as pd
import sys

from dataiku import Dataset
from dataiku.base.folder_context import build_folder_context
from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.core import debugging
from dataiku.core import dkujson
from dataiku.core import doctor_constants
from dataiku.core.doctor_constants import TIME_VARIABLE
from dataiku.doctor import utils
from dataiku.doctor.diagnostics import default_diagnostics
from dataiku.doctor.exception import EmptyDatasetException
from dataiku.doctor.timeseries.evaluate.evaluation_handler import TimeseriesEvaluationHandler
from dataiku.doctor.timeseries.perf.model_perf import AGGREGATED_TIMESERIES_METRICS
from dataiku.doctor.timeseries.utils import prefix_custom_metric_name, _groupby_compat
from dataiku.doctor.timeseries.utils.recipes import get_input_df_and_parameters
from dataiku.doctor.timeseries.utils.recipes import load_model_partitions
from dataiku.doctor.utils import epoch_to_temporal
from dataiku.doctor.utils.gpu_execution import get_gpu_config_from_recipe_desc, log_nvidia_smi_if_use_gpu
from dataiku.doctor.utils.scoring_recipe_utils import generate_part_df_and_model_params
from dataiku.doctor.utils.scoring_recipe_utils import get_empty_pred_df
from dataiku.doctor.utils.scoring_recipe_utils import get_partition_columns

logger = logging.getLogger(__name__)


def main(model_folder, input_dataset_smartname, output_dataset_smartname, metrics_dataset_smartname, recipe_desc,
         script, preparation_output_schema, model_evaluation_store_folder, diagnostics_folder, fmi):
    """
    Run the timeseries evaluation recipe: evaluate the model (per partition when the model is partitioned) on the
    input dataset, then write the forecasts and/or the metrics to the output datasets that are configured.

    :param model_folder: location of the model, turned into a folder context
    :param input_dataset_smartname: name of the dataset to evaluate on
    :param output_dataset_smartname: name of the forecasts dataset; nothing is written when None or empty
    :param metrics_dataset_smartname: name of the metrics dataset; when None or empty, metrics are neither
        requested from the evaluation handler nor written
    :param recipe_desc: recipe description dict. Keys read here: "maxNbForecastTimeSteps" (optional),
        "maxNbForecastHorizons" (only when the former is missing), "computePerTimeseriesMetrics", "metrics",
        "customMetrics" and "refitModel". It is not modified by this function.
    :param script: forwarded to get_input_df_and_parameters
    :param preparation_output_schema: forwarded to get_input_df_and_parameters and to the evaluation handler
    :param model_evaluation_store_folder: optional, turned into a folder context when set
    :param diagnostics_folder: optional, turned into a folder context when set
    :param fmi: forwarded to load_model_partitions
    :raises EmptyDatasetException: if the input dataframe is empty
    """
    has_output_dataset = bool(output_dataset_smartname)
    has_metrics_dataset = bool(metrics_dataset_smartname)

    nb_forecast_timesteps = recipe_desc.get('maxNbForecastTimeSteps')  # could be None

    compute_per_timeseries_metrics = recipe_desc["computePerTimeseriesMetrics"]
    # Build a new list: extending recipe_desc["metrics"] with += would append the custom metrics to the caller's
    # recipe description in place
    output_metrics = list(recipe_desc["metrics"]) + [
        prefix_custom_metric_name(custom_metric) for custom_metric in recipe_desc["customMetrics"]
    ]

    model_folder_context = build_folder_context(model_folder)
    model_evaluation_store_folder_context = build_folder_context(model_evaluation_store_folder) if model_evaluation_store_folder else None
    diagnostics_folder_context = build_folder_context(diagnostics_folder) if diagnostics_folder else None
    input_df, input_df_copy_unnormalized, core_params, quantiles, _, columns_to_drop = get_input_df_and_parameters(
        model_folder_context, input_dataset_smartname, recipe_desc, preparation_output_schema, script
    )

    if input_df.empty:
        raise EmptyDatasetException("The evaluation dataset can not be empty. Check the input dataset or the recipe sampling configuration.")

    log_nvidia_smi_if_use_gpu(recipe_desc=recipe_desc)

    logger.info("Scoring data for evaluation")
    partition_columns = get_partition_columns(model_folder_context, core_params)
    recipe_gpu_config = get_gpu_config_from_recipe_desc(recipe_desc)
    partition_dispatch, partitions, _ = load_model_partitions(model_folder_context, recipe_gpu_config, fmi)

    default_diagnostics.register_forecasting_evaluation_callbacks()

    # Only non-empty per-partition results are collected
    part_forecasts_dfs = []
    part_metrics_dfs = []
    for part_df, partition_params, _ in generate_part_df_and_model_params(input_df, partition_dispatch, core_params,
                                                                       partitions, raise_if_not_found=False):

        if nb_forecast_timesteps is None:
            # backwards compat for eval recipes created pre-14.1
            max_nb_forecast_horizons = recipe_desc['maxNbForecastHorizons']
            nb_forecast_timesteps = core_params["predictionLength"] * max_nb_forecast_horizons

        # The per-partition model folder is not needed (and must not shadow the model_folder argument)
        _partition_model_folder, preprocessing_params, clf, modeling_params, metrics_params = partition_params
        # We can use the model folder context here for the preprocessing_folder param of TimeseriesEvaluationHandler,
        # as the eval handler is always used by SMs (where preprocessing & model folders are the same)
        evaluation_handler = TimeseriesEvaluationHandler(
            core_params, preprocessing_params, modeling_params, metrics_params, clf, model_folder_context, model_evaluation_store_folder_context, diagnostics_folder_context
        )
        part_forecasts_df, part_metrics_df = evaluation_handler.evaluate(
            part_df,
            preparation_output_schema,
            quantiles,
            nb_forecast_timesteps,
            output_metrics,
            partition_columns,
            compute_metrics=has_metrics_dataset,
            compute_per_timeseries_metrics=compute_per_timeseries_metrics,
            refit=recipe_desc["refitModel"],
        )
        if not part_forecasts_df.empty:
            part_forecasts_dfs.append(part_forecasts_df)
        if not part_metrics_df.empty:
            part_metrics_dfs.append(part_metrics_df)

    # write scored data
    if has_output_dataset:
        output_dataset = Dataset(output_dataset_smartname)

        if partition_dispatch:
            if part_forecasts_dfs:
                forecasts_df = pd.concat(part_forecasts_dfs, axis=0)
            else:
                logger.warning("All partitions found in dataset are unknown to the model, forecast will be empty")
                forecasts_df = get_empty_pred_df(input_df.columns, output_dataset.read_schema())
        else:
            # NOTE(review): empty dataframes are not collected above, so this raises IndexError if the
            # evaluation returned an empty forecasts dataframe - confirm that this cannot happen
            forecasts_df = part_forecasts_dfs[0]

        forecasts_df.drop(columns_to_drop, axis=1, inplace=True)

        time_variable = core_params[TIME_VARIABLE]
        output_columns = {sc["name"]: sc for sc in output_dataset.read_schema()}
        timezone_aware_types = ('date', 'datetimetz')
        temporal_types = ('date', 'datetimetz', 'datetimenotz', 'dateonly')
        # put back the input columns in their non-preprocessed state (very visible for temporal columns)
        # and undo the part "remove timezone because of gluon" before writing out
        # note that this timezone clearing is only done on the TIME variable
        for name in forecasts_df.columns:
            sc = output_columns.get(name)
            if sc is None:
                continue  # more like a oops, but who knows, maybe it'll just make an empty column
            column_type = sc.get('type')
            if name == time_variable:
                # time variable => it's a datetime64[ns] in the df
                if column_type in timezone_aware_types:
                    # relocalize
                    forecasts_df[name] = forecasts_df[name].dt.tz_localize('UTC')
                # nothing more to do for datetimenotz or dateonly
            elif column_type in temporal_types:
                # un-'normalize' column from numeric to the proper temporal types
                forecasts_df[name] = epoch_to_temporal(forecasts_df[name], column_type)

        output_dataset.write_dataframe(forecasts_df)
        logger.info("Wrote an output dataframe of shape %s", forecasts_df.shape)

    # write metrics dataset
    if has_metrics_dataset:
        metrics_dataset = Dataset(metrics_dataset_smartname)

        if partition_dispatch:
            if part_metrics_dfs:
                metrics_df = pd.concat(part_metrics_dfs, axis=0)
                metrics_df = _aggregate_partition_metrics(
                    metrics_df,
                    compute_per_timeseries_metrics,
                    core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS]
                )
            else:
                logger.warning("All partitions found in dataset are unknown to the model, metrics will be empty")
                metrics_df = pd.DataFrame(columns=[c["name"] for c in metrics_dataset.read_schema()])
        else:
            # NOTE(review): same as for the forecasts, this raises IndexError if the evaluation returned an
            # empty metrics dataframe - confirm that this cannot happen
            metrics_df = part_metrics_dfs[0]

        metrics_dataset.write_dataframe(metrics_df)

def _aggregate_partition_metrics(metrics_df, compute_per_timeseries_metrics, timeseries_identifiers):
    """
    Aggregate all partition metrics by mean (in partition dispatch mode), same as in the StratifiedMetricsAggregator

    Only the columns of metrics_df that are listed in AGGREGATED_TIMESERIES_METRICS are aggregated, and they are
    kept in the order in which they appear in metrics_df.

    :param metrics_df: dataframe holding the metrics of all the partitions
    :param compute_per_timeseries_metrics: if truthy, the mean is computed per group of timeseries identifiers,
        else a single row holding the mean over all the rows is returned
    :param timeseries_identifiers: names of the columns to group by (only used for per-timeseries metrics)
    :return: the aggregated metrics, with an additional "date" column set to utils.get_datetime_now_utc()
    """
    # TODO check if custom metrics are used here
    aggregated_metric_names = set(AGGREGATED_TIMESERIES_METRICS)
    # Follow the column order of metrics_df: iterating over a set intersection yields an arbitrary order, which
    # can change from one run to the next and made the column order of the result unpredictable
    all_metrics_columns = [column for column in metrics_df.columns if column in aggregated_metric_names]

    if compute_per_timeseries_metrics:
        grouped_metrics = metrics_df.groupby(_groupby_compat(timeseries_identifiers))[all_metrics_columns]
        aggregated_metrics_df = grouped_metrics.mean().reset_index()
    else:
        # mean(axis=0) gives one value per metric column, transposed into a single-row dataframe
        aggregated_metrics_df = pd.DataFrame(metrics_df[all_metrics_columns].mean(axis=0)).T

    aggregated_metrics_df["date"] = utils.get_datetime_now_utc()

    return aggregated_metrics_df


if __name__ == "__main__":
    # Script entry point: set up debugging and logging, load the DKU environment, then run the recipe
    debugging.install_handler()
    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
    read_dku_env_and_set()

    with ErrorMonitoringWrapper():
        # Arguments 5 to 7 are paths to JSON files that are loaded before being handed over to main,
        # all the other arguments are passed through as plain strings
        main(
            model_folder=sys.argv[1],
            input_dataset_smartname=sys.argv[2],
            output_dataset_smartname=sys.argv[3],
            metrics_dataset_smartname=sys.argv[4],
            recipe_desc=dkujson.load_from_filepath(sys.argv[5]),
            script=dkujson.load_from_filepath(sys.argv[6]),
            preparation_output_schema=dkujson.load_from_filepath(sys.argv[7]),
            model_evaluation_store_folder=sys.argv[8],
            diagnostics_folder=sys.argv[9],
            fmi=sys.argv[10],
        )
