import logging
import numpy as np
import pandas as pd

from dataiku.base.utils import safe_unicode_str
from dataiku.core.dataset import Dataset
from dataiku.core.dku_pandas_csv import pandas_date_parser_compat
from dataiku.doctor import utils
from dataiku.doctor.preprocessing.assertions import compute_assertions_masks
from dataiku.doctor.diagnostics import diagnostics


logger = logging.getLogger(__name__)


def df_from_split_desc_no_normalization(split_desc, split, split_folder_context, feature_params, prediction_type=None):
    """Load one split described by ``split_desc`` as a DataFrame, without normalization.

    :param split_desc: dict describing the split; this function reads its ``"format"``, ``"schema"``
        and one of ``"fullPath"`` / ``"trainPath"`` / ``"testPath"`` entries
    :param split: ``"full"`` or ``"train"``; any other value selects the test path
    :param split_folder_context: object whose ``get_file_path_to_read(name)`` is used as a context
        manager yielding the path of the file to read
    :param feature_params: forwarded to ``load_df_no_normalization``
    :param prediction_type: forwarded to ``load_df_no_normalization``
    :return: whatever ``load_df_no_normalization`` returns for the selected file
    :raises Exception: if ``split_desc["format"]`` is not ``"csv1"``
    """
    if split_desc["format"] != "csv1":
        raise Exception("Unsupported format")

    # Explicit branches rather than `cond and a or b`: that idiom silently falls back to the test
    # path whenever the train path is falsy (e.g. an empty string), loading the wrong split.
    if split == "full":
        f = split_desc["fullPath"]
    elif split == "train":
        f = split_desc["trainPath"]
    else:
        f = split_desc["testPath"]

    with split_folder_context.get_file_path_to_read(f) as split_file_path:
        return load_df_no_normalization(split_file_path, split_desc["schema"], feature_params, prediction_type)


def load_df_no_normalization(filepath, schema, feature_params, prediction_type):
    """Read a tab-separated, header-less file into a DataFrame, without normalization.

    :param filepath: path of the file to read
    :param schema: schema dict; its ``"columns"`` entry provides the column names and date columns
    :param feature_params: forwarded to ``utils.ml_dtypes_from_dss_schema``
    :param prediction_type: forwarded to ``utils.ml_dtypes_from_dss_schema``
    :return: the loaded DataFrame, after ``pandas_date_parser_compat`` has been applied to the
        parsed date columns with a UTC ``pd.to_datetime`` converter
    """
    (names, dtypes, parse_date_columns) = Dataset.get_dataframe_schema_st(
        schema["columns"], parse_dates=True, infer_with_pandas=True)
    # These dtypes are only logged: the ones actually used for reading are computed just below.
    logger.info("Reading with dtypes: %s", dtypes)
    dtypes = utils.ml_dtypes_from_dss_schema(schema,
                                             feature_params,
                                             prediction_type=prediction_type)

    # NOTE: boolean columns used to be forced to "str" at this point. If pandas infers them, the
    # original syntax is lost: a true/false target becomes True/False, and remapping with
    # true/false then ends up with no target at all. That override is no longer applied here;
    # presumably utils.ml_dtypes_from_dss_schema covers it now -- TODO confirm.
    logger.info("Reading with FIXED dtypes: %s", dtypes)
    df = pd.read_table(filepath,
                       names=names,
                       dtype=dtypes,
                       header=None,
                       sep='\t',
                       doublequote=True,
                       quotechar='"',
                       parse_dates=parse_date_columns,
                       float_precision="round_trip")

    # used because pandas>0.23 doesn't automatically convert into UTC, also see https://github.com/pandas-dev/pandas/issues/50601
    df = pandas_date_parser_compat(df, parse_date_columns, lambda col: pd.to_datetime(col, utc=True))

    logger.info("Loaded table")
    return df


def load_df_with_normalization(filename, folder_context, schema, feature_params, prediction_type):
    """Load ``filename`` from ``folder_context`` and return the normalized DataFrame.

    The file is read with ``load_df_no_normalization`` while the path obtained from
    ``folder_context.get_file_path_to_read`` is held; the result is then passed through
    ``utils.normalize_dataframe`` with ``feature_params``.
    """
    with folder_context.get_file_path_to_read(filename) as local_path:
        raw_df = load_df_no_normalization(local_path, schema, feature_params, prediction_type)
    # Normalization runs after the file path context has been exited
    normalized_df = utils.normalize_dataframe(raw_df, feature_params)
    return normalized_df


def df_from_split_desc(split_desc, split, split_folder_context, feature_params, prediction_type=None, assertions=None):
    """Load a split as a normalized DataFrame, optionally with assertion mask columns appended.

    When ``assertions`` is provided and non-empty, the masks computed by
    ``compute_assertions_masks`` on the raw (not yet normalized) DataFrame are concatenated
    column-wise to it before normalization.
    """
    raw_df = df_from_split_desc_no_normalization(split_desc, split, split_folder_context, feature_params,
                                                 prediction_type)

    nothing_to_assert = assertions is None or len(assertions) == 0
    if nothing_to_assert:
        return utils.normalize_dataframe(raw_df, feature_params)

    masks = compute_assertions_masks(assertions, raw_df)
    df_with_masks = pd.concat([raw_df, masks], axis=1)
    return utils.normalize_dataframe(df_with_masks, feature_params)


def input_columns(per_feature):
    """Return the names of the features whose ``"role"`` is ``"INPUT"``.

    :param per_feature: mapping of feature name to a details dict holding a ``"role"`` entry
    :return: list of feature names, in the iteration order of ``per_feature``
    """
    selected = []
    for name, details in per_feature.items():
        if details["role"] == "INPUT":
            selected.append(name)
    return selected


def load_train_set(core_params, preprocessing_params, split_desc, name, split_folder_context, assertions=None,
                   use_diagnostics=True):
    """Load the split called ``name`` as the normalized train DataFrame.

    :param core_params: dict providing ``"prediction_type"`` and, when diagnostics are enabled,
        ``"target_variable"``
    :param preprocessing_params: dict whose ``"per_feature"`` entry is used as feature params
    :param split_desc: split description forwarded to ``df_from_split_desc``
    :param name: split name forwarded to ``df_from_split_desc``
    :param split_folder_context: folder context forwarded to ``df_from_split_desc``
    :param assertions: optional assertions forwarded to ``df_from_split_desc``
    :param use_diagnostics: when true, ``diagnostics.on_load_train_dataset_end`` is called
    :return: the loaded train DataFrame
    """
    per_feature = preprocessing_params['per_feature']
    prediction_type = core_params["prediction_type"]

    train_df = df_from_split_desc(split_desc, name, split_folder_context, per_feature,
                                  prediction_type, assertions=assertions)

    if use_diagnostics:
        diagnostics.on_load_train_dataset_end(prediction_type=core_params["prediction_type"],
                                              df=train_df,
                                              target_variable=core_params["target_variable"])

    shape_message = "Loaded train df: shape=(%d,%d)" % train_df.shape
    logger.info(shape_message)
    return train_df


def load_test_set(core_params, preprocessing_params, split_desc, split_folder_context, assertions=None, use_diagnostics=True):
    """Load the ``"test"`` split as a normalized DataFrame.

    :param core_params: dict providing ``"prediction_type"`` and, when diagnostics are enabled,
        ``"target_variable"``
    :param preprocessing_params: dict whose ``"per_feature"`` entry is used as feature params
    :param split_desc: split description forwarded to ``df_from_split_desc``
    :param split_folder_context: folder context forwarded to ``df_from_split_desc``
    :param assertions: optional assertions forwarded to ``df_from_split_desc``
    :param use_diagnostics: when true, ``diagnostics.on_load_test_dataset_end`` is called
    :return: the loaded test DataFrame
    """
    per_feature = preprocessing_params["per_feature"]
    prediction_type = core_params["prediction_type"]

    test_df = df_from_split_desc(split_desc, "test", split_folder_context, per_feature,
                                 prediction_type, assertions=assertions)

    if use_diagnostics:
        diagnostics.on_load_test_dataset_end(prediction_type=core_params["prediction_type"],
                                             df=test_df,
                                             target_variable=core_params['target_variable'])

    shape_message = "Loaded test df: shape=(%d,%d)" % test_df.shape
    logger.info(shape_message)
    return test_df


def check_train_test_order(train_df, test_df, time_variable, ascending):
    """Check that the test set does not overlap the train set along ``time_variable``.

    The bounds of the train set are computed over all of its values, so the check does not depend
    on the row order of ``train_df``.

    :param train_df: train DataFrame, must contain the ``time_variable`` column
    :param test_df: test DataFrame, must contain the ``time_variable`` column
    :param time_variable: name of the column holding the time values
    :param ascending: if true, every test value must be >= every train value; otherwise every
        test value must be <= every train value
    :raises ValueError: if a numeric train time column contains NaN, or if the ordering is violated
    """
    time_train_arr = train_df[time_variable].values
    time_test_arr = test_df[time_variable].values

    # NOTE(review): only numeric dtypes are checked for missing values here; other dtypes
    # (e.g. datetimes) are not checked for missing values by this function.
    if np.issubdtype(time_train_arr.dtype, np.number) and np.any(np.isnan(time_train_arr)):
        raise ValueError("Train set should have no empty or NaN values " +
                         "for time variable '{}'".format(safe_unicode_str(time_variable)))

    if ascending:
        # Use the real maximum instead of the last row: the error message promises a comparison
        # against *all* train values, which a positional lookup only honours for sorted input.
        max_train = np.max(time_train_arr)
        min_test = np.min(time_test_arr)
        if max_train > min_test:
            raise ValueError("Test set should have values greater or equal to all values of train set " +
                             "(max train = {max_train}, min test = {min_test})".format(max_train=max_train,
                                                                                       min_test=min_test))
    else:
        # Use the real minimum instead of the first row: on a frame sorted in descending order
        # the first row is the maximum, which would let overlapping test values go undetected.
        min_train = np.min(time_train_arr)
        max_test = np.max(time_test_arr)
        if max_test > min_train:
            raise ValueError("Test set should have values lower or equal to all values of train set " +
                             "(min train = {min_train}, max test = {max_test})".format(min_train=min_train,
                                                                                       max_test=max_test))


def check_dataframe_sorted(df, by_variable, ascending):
    """Tell whether column ``by_variable`` of ``df`` is already in the requested order.

    :param df: DataFrame to inspect
    :param by_variable: name of the column to check
    :param ascending: if true, check ``is_monotonic_increasing``; otherwise check
        ``is_monotonic_decreasing``
    :return: the value of the corresponding pandas monotonicity property
    """
    column = df[by_variable]
    return column.is_monotonic_increasing if ascending else column.is_monotonic_decreasing


def sort_dataframe(df, by_variable, ascending):
    """Sort ``df`` in place by ``by_variable`` unless it is already in the requested order.

    :param df: DataFrame to sort; it is modified in place and nothing is returned
    :param by_variable: name of the column to sort by
    :param ascending: sort direction, also used for the "already sorted" check
    """
    if check_dataframe_sorted(df, by_variable, ascending):
        # Already in the requested order: leave the frame untouched
        return

    message = u"Dataframe not sorted, sorting by '{column}', ascending={ascending}".format(
        column=safe_unicode_str(by_variable), ascending=ascending)
    logger.info(message)
    df.sort_values(by=by_variable, inplace=True, ascending=ascending)
