import logging
import pandas as pd
import sys

from dataiku import Dataset
from dataiku.base.folder_context import build_folder_context
from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.core import debugging
from dataiku.core import dkujson
from dataiku.core.doctor_constants import PREDICTION_LENGTH
from dataiku.core.doctor_constants import TIME_VARIABLE
from dataiku.doctor.timeseries.utils.recipes import get_input_df_and_parameters
from dataiku.doctor.timeseries.utils.recipes import load_model_partitions
from dataiku.doctor.utils import epoch_to_temporal
from dataiku.doctor.utils.gpu_execution import get_gpu_config_from_recipe_desc, log_nvidia_smi_if_use_gpu
from dataiku.doctor.utils.scoring_recipe_utils import add_fmi_metadata
from dataiku.doctor.utils.scoring_recipe_utils import add_prediction_time_metadata
from dataiku.doctor.utils.scoring_recipe_utils import generate_part_df_and_model_params
from dataiku.doctor.utils.scoring_recipe_utils import get_empty_pred_df
from dataiku.doctor.utils.scoring_recipe_utils import get_partition_columns
from dataiku.doctor.utils.scoring_recipe_utils import smmd_colnames

from dataiku.doctor.timeseries.score.scoring_handler import TimeseriesScoringHandler
from dataiku.doctor.diagnostics import default_diagnostics


logger = logging.getLogger("timeseries_scoring_recipe")


def _restore_temporal_columns(forecasts_df, output_schema, time_variable):
    """Convert forecast columns back to the temporal types declared in the output schema.

    Undoes the input "normalization" before writing: columns whose output type is
    temporal are converted back with ``epoch_to_temporal``. The time variable is only
    re-localized to UTC (for ``date``/``datetimetz`` outputs), undoing the timezone
    removal done before scoring. Columns absent from ``output_schema`` are left as is.

    :param forecasts_df: forecasts dataframe, modified in place
    :param output_schema: list of column dicts (``name``/``type`` keys), as returned by ``Dataset.read_schema()``
    :param time_variable: name of the time column
    """
    output_columns = {sc["name"]: sc for sc in output_schema}
    for name in forecasts_df.columns:
        sc = output_columns.get(name)
        if sc is None:
            continue  # more like a oops, but who knows, maybe it'll just make an empty column
        column_type = sc.get("type")
        if name == time_variable:
            # time variable => it's a datetime64[ns] in the df; the timezone was removed (because of gluon),
            # so relocalize it. Nothing more to do for datetimenotz or dateonly.
            if column_type in ("date", "datetimetz"):
                forecasts_df[name] = forecasts_df[name].dt.tz_localize("UTC")
        elif column_type in ("date", "datetimetz", "datetimenotz", "dateonly"):
            # un-'normalize' column from numeric to the proper temporal type
            forecasts_df[name] = epoch_to_temporal(forecasts_df[name], column_type)


def main(
    model_folder,
    input_dataset_smartname,
    output_dataset_smartname,
    recipe_desc,
    script,
    preparation_output_schema,
    fmi,
    diagnostics_folder,
):
    """Score a saved time series model on an input dataset and write the forecasts to the output dataset.

    :param model_folder: path of the saved model folder
    :param input_dataset_smartname: smart name of the input dataset
    :param output_dataset_smartname: smart name of the output dataset receiving the forecasts
    :param recipe_desc: recipe description dict; reads ``predictionLength`` (optional, defaults to the
        training horizon), ``refitModel`` and ``outputModelMetadata`` (optional, defaults to False)
    :param script: preparation script, forwarded to ``get_input_df_and_parameters``
    :param preparation_output_schema: schema forwarded to input loading and to ``TimeseriesScoringHandler.score``
    :param fmi: model identifier passed to ``load_model_partitions`` and, for non-partitioned models,
        to ``add_fmi_metadata``
    :param diagnostics_folder: path of the diagnostics folder, or a falsy value for no diagnostics folder
    """
    logger.info("Starting time series scoring")
    model_folder_context = build_folder_context(model_folder)
    diagnostics_folder_context = build_folder_context(diagnostics_folder) if diagnostics_folder else None

    output_dataset = Dataset(output_dataset_smartname)

    input_df, _, core_params, quantiles, past_time_steps_to_include, columns_to_drop = get_input_df_and_parameters(
        model_folder_context, input_dataset_smartname, recipe_desc, preparation_output_schema, script
    )

    recipe_gpu_config = get_gpu_config_from_recipe_desc(recipe_desc)
    partition_columns = get_partition_columns(model_folder_context, core_params)
    partition_dispatch, model_partitions, partitions_fmis = load_model_partitions(model_folder_context, recipe_gpu_config, fmi)

    default_diagnostics.register_forecasting_scoring_callbacks()

    log_nvidia_smi_if_use_gpu(recipe_desc=recipe_desc)

    if "predictionLength" in recipe_desc:
        scoring_prediction_length = recipe_desc["predictionLength"]
    else:
        scoring_prediction_length = core_params[PREDICTION_LENGTH]
        logger.warning("Missing forecast prediction length, defaulting to horizon set in training: {}".format(scoring_prediction_length))

    output_model_metadata = recipe_desc.get("outputModelMetadata", False)

    part_dfs = []
    # Tracked across partitions (rather than reading the last partition's modeling params after the loop)
    # so it is always defined, even when no partition of the input is known to the model.
    has_shift_windows = False
    for part_df, partition_clf_and_params, partition_id in generate_part_df_and_model_params(
            input_df, partition_dispatch, core_params, model_partitions, raise_if_not_found=False):
        partition_model_folder, preprocessing_params, clf, modeling_params, resolved_modeling_params, _ = partition_clf_and_params
        # Here we use model_folder as preproc_folder, as this code is only accessed by saved models (where preprocessing & model folders are the same), not trained analysis models
        # NOTE(review): the handler receives the root `model_folder_context`, while `partition_model_folder`
        # is unused; confirm this is intended for partitioned models.
        scoring_handler = TimeseriesScoringHandler(
            core_params, preprocessing_params, modeling_params, resolved_modeling_params, clf, model_folder_context,
            diagnostics_folder_context=diagnostics_folder_context,
            scoring_prediction_length=scoring_prediction_length
        )
        # Per-partition value: an algorithm without quantile support must not disable quantiles
        # for the partitions scored after it.
        partition_quantiles = quantiles if scoring_handler.algorithm.SUPPORTS_QUANTILES else []
        part_forecasts_df = scoring_handler.score(
            part_df,
            preparation_output_schema,
            partition_quantiles,
            past_time_steps_to_include,
            partition_columns=partition_columns,
            refit=recipe_desc["refitModel"],
        )

        if output_model_metadata:
            if partition_id is not None:
                add_fmi_metadata(part_forecasts_df, partitions_fmis[partition_id])
            else:
                add_fmi_metadata(part_forecasts_df, fmi)

        if modeling_params.get("isShiftWindowsCompatible", False):
            has_shift_windows = True

        part_dfs.append(part_forecasts_df)

    output_schema = output_dataset.read_schema()

    if partition_dispatch:
        if part_dfs:
            forecasts_df = pd.concat(part_dfs, axis=0)
        else:
            logger.warning("All partitions found in dataset are unknown to the model, forecast will be empty")
            forecasts_df = get_empty_pred_df(input_df.columns, output_schema)
    else:
        forecasts_df = part_dfs[0]

    # TODO: clean up by keeping track of the list of window columns
    if has_shift_windows:
        columns_to_drop = columns_to_drop + [
            column for column in forecasts_df.columns
            if column.startswith("rolling_window:") and column not in columns_to_drop
        ]

    forecasts_df.drop(columns_to_drop, axis=1, inplace=True)

    # put back the input columns in their non-preprocessed state (very visible for temporal columns)
    # and undo the part "remove timezone because of gluon" before writing out
    _restore_temporal_columns(forecasts_df, output_schema, core_params[TIME_VARIABLE])

    if output_model_metadata:
        # add the "prediction time" and move the output model metadata columns to the end
        add_prediction_time_metadata(forecasts_df)
        ordered_columns = [col for col in forecasts_df.columns if col not in smmd_colnames] + smmd_colnames
        forecasts_df = forecasts_df[ordered_columns]

    output_dataset.write_dataframe(forecasts_df)
    logger.info("Wrote an output dataframe of shape {}".format(forecasts_df.shape))


if __name__ == "__main__":
    # Process-level setup must happen before any scoring work starts.
    debugging.install_handler()
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    read_dku_env_and_set()

    # Command-line layout: model folder, input dataset, output dataset, then three JSON files
    # (recipe desc, script, preparation output schema), then the model identifier and the diagnostics folder.
    # Keyword arguments are evaluated left to right, so argv access and JSON loading keep their original order.
    with ErrorMonitoringWrapper():
        main(
            model_folder=sys.argv[1],
            input_dataset_smartname=sys.argv[2],
            output_dataset_smartname=sys.argv[3],
            recipe_desc=dkujson.load_from_filepath(sys.argv[4]),
            script=dkujson.load_from_filepath(sys.argv[5]),
            preparation_output_schema=dkujson.load_from_filepath(sys.argv[6]),
            fmi=sys.argv[7],
            diagnostics_folder=sys.argv[8],
        )
