import datetime
from dateutil.tz import tzutc
import logging
import pandas as pd

from dataiku import Dataset
from dataiku.core import default_project_key
from dataiku.core.base import PartitionEscaper
from dataiku.doctor.utils import normalize_dataframe
from dataiku.doctor.utils import ml_dtypes_from_dss_schema

logger = logging.getLogger(__name__)

# Prefix shared by every "scoring model metadata" column added to scored dataframes.
prefix_col_name = "smmd_"

# Order matters: the functions of this module address these columns by position
# (0 = saved model id, 1 = model version, 2 = full model id, 3 = prediction time).
smmd_colnames = [
    prefix_col_name + "savedModelId",
    prefix_col_name + "modelVersion",
    prefix_col_name + "fullModelId",
    prefix_col_name + "predictionTime",
]


def add_output_model_metadata(df, fmi):
    """
    Add, in place, the 4 model metadata columns used by the scoring recipe to ``df``.

    The 3 columns derived from the full model id are added first, then the
    prediction time column.

    The column names are the same as the ones in ScoringRecipeUtils.java: both
    sides have to be changed together in case of update.

    :param df: pd.DataFrame, dataframe receiving the model metadata columns
    :param fmi: String, full model id from which the model metadata is extracted
    :return: void
    """
    # Column order in the output follows the call order below.
    add_fmi_metadata(df, fmi)
    add_prediction_time_metadata(df)


def add_fmi_metadata(df, fmi):
    """
    Add, in place, the 3 columns carrying the full model id (FMI) metadata to ``df``.

    The FMI is split on '-': its third token fills the saved model id column,
    its fourth token fills the model version column, and the whole FMI fills the
    full model id column.

    :param df: pd.DataFrame, dataframe receiving the model metadata columns
    :param fmi: String, full model id from which the model metadata is extracted
    :return: void
    """
    logger.info("Add FMI's Metadata Columns")
    saved_model_col, model_version_col, full_model_col = smmd_colnames[:3]
    fmi_tokens = fmi.split('-')
    # Each token is read right before its assignment, so a malformed FMI fails
    # at the same point (and with the same columns already set) as before.
    df[saved_model_col] = fmi_tokens[2]
    df[model_version_col] = fmi_tokens[3]
    df[full_model_col] = fmi


def add_prediction_time_metadata(df):
    """
    Add, in place, a column with the 'prediction time' model metadata info to ``df``.

    It's not the exact time of every prediction but a single global execution time
    (taken once, in UTC, when this function is called) used for the lineage of a model.
    The value is formatted as ``YYYY-MM-DDTHH:MM:SS.mmmZ`` with millisecond precision.

    :param df: pd.DataFrame, dataframe where we add model metadata info
    :return: void
    """
    logger.info("Add prediction time Metadata Column")
    pred_time = datetime.datetime.now(tzutc())
    # Milliseconds must be zero-padded to 3 digits: without padding, 5 ms would be
    # written as ".5Z", which any ISO-8601 parser reads back as 500 ms.
    df[smmd_colnames[3]] = "{}.{:03d}Z".format(
        pred_time.strftime('%Y-%m-%dT%H:%M:%S'),
        pred_time.microsecond // 1000
    )


def get_dataframe_dtypes_info(schema, feature_params, prediction_type=None, partition_columns=None):
    """
    Compute the column names, dtypes and date columns to use when reading a dataset.

    Names and date columns come from ``Dataset.get_dataframe_schema_st`` applied to
    ``schema["columns"]``; the dtypes come from ``ml_dtypes_from_dss_schema``. When
    partition columns are given, they are forced to the "str" dtype and removed from
    the date columns to parse.

    :param schema: dataset schema, must contain a "columns" entry
    :param feature_params: per-feature params, forwarded to ml_dtypes_from_dss_schema
    :param prediction_type: optional prediction type, forwarded to ml_dtypes_from_dss_schema
    :param partition_columns: optional collection of partition column names
    :return: tuple (names, dtypes, parse_date_columns)
    """
    names, initial_dtypes, parse_date_columns = Dataset.get_dataframe_schema_st(
        schema["columns"], parse_dates=True, infer_with_pandas=False
    )
    logger.info("Reading with INITIAL dtypes: {}".format(initial_dtypes))
    dtypes = ml_dtypes_from_dss_schema(schema, feature_params, prediction_type=prediction_type)

    if partition_columns:
        logger.info("Forcing dtype of partition columns '%s' to be 'str" % (partition_columns))
        # Mind the case of temporal columns: if a dtype is set on them, pandas parses
        # them, then casts to str. For "date" (= datetime with tz) this is fine because
        # the string representation is a proper iso8601 string that can be parsed again.
        # For dateonly and datetime without tz, pandas makes a datetime64[ns] (no tz
        # indication) then casts to string as a unix timestamp in nanoseconds, which
        # can't be parsed back easily.
        # => a partition column must not stay in parse_date_columns.
        if isinstance(parse_date_columns, list):
            # parse_date_columns holds column positions; it is edited in place.
            for position, name in enumerate(names):
                if name in partition_columns and position in parse_date_columns:
                    parse_date_columns.remove(position)
        for partition_column in partition_columns:
            dtypes[partition_column] = "str"

    logger.info("Reading with dtypes: {}".format(dtypes))
    for position, name in enumerate(names):
        logger.info("Column %s = %s (dtype=%s)" % (position, name, dtypes.get(name, None)))
    return names, dtypes, parse_date_columns


def dataframe_iterator(dataset, names, dtypes, parse_date_columns, feature_params,
                       batch_size=10000, sampling="head", float_precision=None,
                       normalize=True):
    """
    Iterate over the dataframes of ``dataset``, read with forced types and UTC dates.

    Wraps Dataset.iter_dataframes_forced_types with a custom date_parser that converts
    dates to UTC. Each dataframe gets a fresh 0..n-1 index and is normalized in place
    with ``normalize_dataframe`` when ``normalize`` is True.
    Returns the whole dataframe in the first iteration when batch_size is None.
    Note: date_parser is deprecated after pandas 2.0

    :return: generator of (dataframe, copy of the dataframe taken before normalization)
    """
    def _parse_dates_as_utc(col):
        return pd.to_datetime(col, utc=True)

    chunks = dataset.iter_dataframes_forced_types(
        names, dtypes, parse_date_columns,
        chunksize=batch_size,
        sampling=sampling,
        float_precision=float_precision,
        date_parser=_parse_dates_as_utc
    )
    for chunk in chunks:
        logger.info("Got a dataframe of shape: {}".format(chunk.shape))
        chunk.index = range(chunk.shape[0])
        # Snapshot taken before normalization, so callers keep the raw values.
        unnormalized_chunk = chunk.copy()
        if normalize:
            logger.info("Normalizing dataframe")
            normalize_dataframe(chunk, feature_params)
            for column in chunk:
                logger.info("Normalized column: %s -> %s" % (column, chunk[column].dtype))
        yield chunk, unnormalized_chunk


def is_partition_dispatch(model_folder_context):
    """
    Tell whether the model folder holds a partition dispatch file.

    :param model_folder_context: object exposing ``isfile(name)`` for the model folder
    :return: result of the first ``isfile`` check that is truthy, else the result of
             the check on the uncompressed file
    """
    # File should be compressed for all models from DSS >= 12.6.3,
    # but the uncompressed name is still checked for backward compatibility.
    compressed_found = model_folder_context.isfile("parts.json.gz")
    if compressed_found:
        return compressed_found
    return model_folder_context.isfile("parts.json")


def get_partition_columns(model_folder_context, core_params):
    """
    Return the partition dimension names of a partitioned model.

    :param model_folder_context: model folder, used to detect partition dispatch
    :param core_params: dict of core params, read under "partitionedModel" / "dimensionNames"
    :return: the "dimensionNames" value (None when absent), or None when the model
             is not a partition dispatch
    """
    if not is_partition_dispatch(model_folder_context):
        return None
    partitioned_model_params = core_params.get("partitionedModel", {})
    return partitioned_model_params.get("dimensionNames")


def generate_part_df_and_model_params(input_df, partition_dispatch, core_params, partitions,
                                      raise_if_not_found=False):
    """
    Yield (dataframe, partition entry, partition id) tuples for the rows of ``input_df``.

    Without partition dispatch, a single tuple is yielded: the whole dataframe, the
    entry stored under "NP" in ``partitions``, and None as partition id.
    With partition dispatch, ``input_df`` is grouped by the partition dimension names
    and one tuple is yielded per group whose partition id is a key of ``partitions``.

    :param input_df: pd.DataFrame to dispatch
    :param partition_dispatch: truthy when the model is dispatched per partition
    :param core_params: dict, its "partitionedModel" entry is read when dispatching
    :param partitions: mapping from partition id (or "NP") to the per-partition entry
    :param raise_if_not_found: raise instead of discarding rows of unknown partitions
    :raises ValueError: unknown partition met while raise_if_not_found is True
    """
    if not partition_dispatch:
        yield (input_df, partitions["NP"], None)
        return

    partitioning_params = core_params["partitionedModel"]
    dimension_names = partitioning_params["dimensionNames"]
    for part_value, part_df in input_df.groupby(dimension_names, sort=False):
        partition_id = PartitionEscaper.build_partition_id(
            part_value, dimension_names, partitioning_params["dimensionTypes"]
        )
        if partition_id in partitions.keys():
            logger.info("Handling partition '%s'" % partition_id)
            yield (part_df, partitions[partition_id], partition_id)
            continue
        if raise_if_not_found:
            raise ValueError("Unknown model partition %s" % partition_id)
        # Best effort: rows of partitions the model does not know are dropped.
        logger.info("Unknown model partition %s, discarding %s rows" % (partition_id, part_df.shape[0]))


def get_input_parameters(model_folder_context, input_dataset_smartname, preparation_output_schema, script, managed_folder_smart_id=None):
    """
    Gather what is needed to read the input dataset of a model.

    :param model_folder_context: model folder, from which "core_params.json" and
                                 "rpreprocessing_params.json" are read
    :param input_dataset_smartname: name used to build the input Dataset
    :param preparation_output_schema: output schema of the preparation script
    :param script: dict, its "steps" entry holds the preparation steps
    :param managed_folder_smart_id: optional managed folder id overriding the one in core params
    :return: tuple (input_dataset, core_params, feature_preproc, names, dtypes, parse_date_columns)
    """
    # Obtain a streamed result of the preparation
    input_dataset = Dataset(input_dataset_smartname)
    logger.info("Will do preparation, output schema: %s" % preparation_output_schema)
    input_dataset.set_preparation_steps(
        script["steps"],
        preparation_output_schema,
        context_project_key=default_project_key()
    )

    # Load common model params
    core_params = model_folder_context.read_json("core_params.json")
    if managed_folder_smart_id:
        # The managed folder used during the training is replaced by the one set in the recipe.
        core_params["managedFolderSmartId"] = managed_folder_smart_id

    # Only using cross-partition safe params: in per_feature[FEATURE_NAME], .type and .role
    preprocessing_params = model_folder_context.read_json("rpreprocessing_params.json")
    feature_preproc = preprocessing_params["per_feature"]

    names, dtypes, parse_date_columns = get_dataframe_dtypes_info(
        preparation_output_schema,
        feature_preproc,
        prediction_type=core_params["prediction_type"],
        partition_columns=get_partition_columns(model_folder_context, core_params)
    )

    return input_dataset, core_params, feature_preproc, names, dtypes, parse_date_columns


def get_empty_pred_df(input_df_columns, output_dataset_schema):
    """
    Output an empty dataframe with the relevant added columns (proba_classX, predict, cond_output...)

    Output schema of Scoring recipe can vary a lot depending on the parameters of the recipe:
     * prediction type (with proba for probabilistic classif for example)
     * proba percentiles
     * conditional outputs
     * ...

    This logic is handled in the backend when creating the recipe. It's also handled in python
    depending on the code path followed according to the params of the recipe.
    In order not to duplicate the logic in python when needing an empty dataframe, we rely on
    the backend created schema.

    :param list(str) input_df_columns: list of input columns
    :param Schema    output_dataset_schema: Output dataset schema
    :return: empty pd.DataFrame whose columns are the named output columns that are
             not input columns, in schema order
    """
    created_columns = []
    for column_spec in output_dataset_schema:
        # Schema entries without a name cannot become a dataframe column.
        if "name" not in column_spec.keys():
            continue
        column_name = column_spec["name"]
        if column_name not in input_df_columns:
            created_columns.append(column_name)
    return pd.DataFrame(columns=created_columns)
