import logging

import numpy as np
import pandas as pd

from dataiku.core import doctor_constants
from dataiku import default_project_key

from dataiku.base.utils import safe_unicode_str, package_is_at_least
from dataiku.doctor.deep_learning import keras_model_io_utils
from dataiku.doctor.deep_learning.keras_utils import split_train_per_input
from dataiku.doctor.deep_learning.keras_utils import retrieve_func_from_code_dict
from dataiku.doctor.deep_learning.keras_utils import tag_special_features
from dataiku.doctor.exception import EmptyDatasetException
from dataiku.doctor.prediction.classification_scoring import binary_classif_scoring_add_percentile_and_cond_outputs
from dataiku.doctor.preprocessing_handler import PredictionPreprocessingHandler
from dataiku.doctor.utils import normalize_dataframe
from dataiku.doctor.utils.api_logs import API_NODE_EVALUATION_DATASET_TYPE, CLOUD_API_NODE_EVALUATION_DATASET_TYPE
from dataiku.doctor.utils.api_logs import normalize_api_node_logs_dataset
from dataiku.doctor.utils.api_logs import CLASSICAL_EVALUATION_DATASET_TYPE
from dataiku.doctor.utils.gpu_execution import KerasGPUCapability, get_gpu_config_from_recipe_desc
from dataiku.doctor.utils.scoring_recipe_utils import add_output_model_metadata
from dataiku.doctor.utils.scoring_recipe_utils import dataframe_iterator
from dataiku.doctor.utils.scoring_recipe_utils import get_dataframe_dtypes_info

logger = logging.getLogger(__name__)

##############################################################
# BUILDING MODEL FROM USER CODE
##############################################################


def get_best_model(train_normal_X, train_df, pipeline, test_df, per_feature, modeling_params, model_folder_context,
                   prediction_type, target_map, generated_features_mapping, gpu_config, save_model=True):
    """
    Build and fit the Keras model described by the user code stored in ``modeling_params["keras"]``,
    then reload it from the model folder.

    :param train_normal_X: preprocessed train data, only used here to derive the shape of each model input
    :param train_df: dataframe handed to the "train" sequence builder
    :param pipeline: preprocessing pipeline handed to both sequence builders
    :param test_df: dataframe handed to the "validation" sequence builder
    :param per_feature: per-feature preprocessing params
    :param modeling_params: modeling params; its "keras" entry holds the build/fit code and the fit settings
    :param model_folder_context: folder context from which the fitted model is reloaded
    :param prediction_type: one of the doctor_constants prediction types
    :param target_map: mapping of the target classes; its length is the number of output labels for
                       classification tasks
    :param generated_features_mapping: mapping forwarded to split_train_per_input and to the sequence builders
    :param gpu_config: GPU configuration; decides whether GPU options are loaded or the GPU is deactivated
    :param save_model: forwarded as-is to build_and_fit_model
    :return: tuple (model reloaded from model_folder_context, validation sequence)
    """
    from dataiku.doctor.deep_learning import gpu
    from dataiku.doctor.deep_learning.sequences import InputsDataWithTargetSequence
    from dataiku.doctor.deep_learning.tfcompat import build_and_fit_model

    # Each model input gets a 1-d shape holding the number of columns routed to it
    arrays_per_input = split_train_per_input(train_normal_X, per_feature, generated_features_mapping)
    input_shapes = {
        input_name: (arrays_per_input[input_name].shape[1],)
        for input_name in arrays_per_input.keys()
    }

    # The user code needs the number of classes in order to build an appropriate network
    is_classification = prediction_type in [doctor_constants.MULTICLASS, doctor_constants.BINARY_CLASSIFICATION]
    output_num_labels = len(target_map) if is_classification else 1

    keras_params = modeling_params["keras"]

    # GPU options must be set (or the GPU deactivated) before any user code is executed
    if KerasGPUCapability.should_use_gpu(gpu_config):
        gpu_params = gpu_config["params"]
        gpu.load_gpu_options(gpu_params["gpuList"],
                             allow_growth=gpu_params["gpuAllowGrowth"],
                             per_process_gpu_memory_fraction=float(gpu_params["perGPUMemoryFraction"]))
    else:
        gpu.deactivate_gpu()

    # Execute the user-provided architecture code in its own namespace
    assert keras_params.get('buildCode', False)
    dic_build = {}
    exec(keras_params["buildCode"], dic_build, dic_build)

    # In advanced fit mode, the user also provides the training loop and the way sequences are built
    fit_model = None
    build_sequences = None
    if keras_params["advancedFitMode"]:
        dic_fit = {}
        exec(keras_params["fitCode"], dic_fit, dic_fit)

        fit_model = retrieve_func_from_code_dict("fit_model", dic_fit, "Training")
        build_sequences = retrieve_func_from_code_dict("build_sequences", dic_fit, "Training")

    # Both builders share everything except the dataframe they read and their name
    shared_builder_args = dict(
        prediction_type=prediction_type,
        pipeline=pipeline,
        per_feature=per_feature,
        generated_features_mapping=generated_features_mapping,
        modeling_params=modeling_params,
        target_map=target_map,
    )
    train_sequence_builder = InputsDataWithTargetSequence.get_sequence_builder(
        input_df=train_df, name="train", **shared_builder_args
    )
    validation_sequence_builder = InputsDataWithTargetSequence.get_sequence_builder(
        input_df=test_df, name="validation", **shared_builder_args
    )

    # Building sequences
    if keras_params["advancedFitMode"]:
        train_sequence, validation_sequence = build_sequences(train_sequence_builder, validation_sequence_builder)
    else:
        batch_size = keras_params["batchSize"]
        train_sequence = train_sequence_builder(batch_size)
        validation_sequence = validation_sequence_builder(batch_size)

    build_and_fit_model(input_shapes, output_num_labels, prediction_type, target_map, modeling_params, keras_params,
                        dic_build, model_folder_context, validation_sequence, save_model, fit_model,
                        train_sequence, gpu_config)

    # The returned model is the one read back from the model folder, not an in-memory training object
    return keras_model_io_utils.load_model(model_folder_context), validation_sequence

##############################################################
# SCORING MODEL
##############################################################

def _update_loss_counters(computed_loss, summed_loss, nb_summed_values):
    """
    Add the loss computed on one batch to the running total and update the count of summed values.

    Depending on the loss function and on the tf1/tf2 environment, the loss computed over a batch is either
    a single value (empty shape) or one value per sample of the batch; both cases are handled.

    :param computed_loss: loss of the batch, exposing a ``shape`` attribute
    :param summed_loss: running total of the loss so far
    :param nb_summed_values: number of loss values accumulated so far
    :return: tuple (updated summed_loss, updated nb_summed_values)
    """
    if computed_loss.shape != ():
        # One value per sample: add them all and count every sample
        summed_loss += sum(computed_loss)
        nb_summed_values += computed_loss.shape[0]
    else:
        # Single aggregated value for the whole batch
        summed_loss += computed_loss
        nb_summed_values += 1
    return summed_loss, nb_summed_values


def build_scored_validation_data(model, prediction_type, modeling_params, valid_iterator, nb_steps=None,
                                 on_step_end_func=None):
    """
    Run the model over the batches of ``valid_iterator`` and collect predictions, probabilities, true targets
    and the average loss.

    :param model: model exposing ``predict`` and a ``loss`` attribute (either a loss name or a loss callable)
    :param prediction_type: one of the doctor_constants prediction types; selects how each batch is scored
    :param modeling_params: modeling params; ``modeling_params["keras"]["oneDimensionalOutput"]`` is read for
                            binary classification
    :param valid_iterator: indexable object returning a ``(X, y)`` tuple for each batch index
    :param nb_steps: number of batches to process; defaults to ``len(valid_iterator)``
    :param on_step_end_func: optional callback, called with the batch index once each batch is processed
    :return: tuple (preds, probas, valid_y, loss) where preds is the concatenation of the per-batch predictions,
             probas is the concatenation of the per-batch probabilities (None for regression), valid_y is a
             pandas Series of the true targets, and loss is the summed loss divided by the number of summed
             loss values
    """
    import tensorflow
    import numpy
    # The numpy unicode string type has a different name depending on the numpy version
    if package_is_at_least(numpy, "2"):
        numpy_unicode = numpy.str_
    elif package_is_at_least(numpy, "1.24"):
        numpy_unicode = numpy.unicode_
    else:
        numpy_unicode = numpy.unicode
    from dataiku.doctor.deep_learning.tfcompat import get_loss

    if nb_steps is None:
        nb_steps = len(valid_iterator)
    probas_list = []
    preds_list = []
    valid_y_list = []
    summed_loss = 0.0
    loss_values_summed = 0  # may be different than len(valid_iterator) in the end since in some cases there is one loss per batch

    # Supports different syntaxes for loss in compile_model for both tf1 and tf2:
    # loss="categorical_crossentropy", loss="CategoricalCrossentropy" (valid for tf2 only) or loss=keras.losses.CategoricalCrossentropy()
    if isinstance(model.loss, str) or isinstance(model.loss, numpy_unicode):
        loss_function = get_loss(model.loss)
    else:
        loss_function = model.loss
    # We're going to have a special tf1 treatment to avoid hitting https://github.com/keras-team/keras/issues/12916
    # My understanding is that we "pin" the loss function in the tensorflow graph thanks to the placeholders, avoiding the
    # later computations on it producing thousands of graph versions, which leads to each call to get_session() (done in eval() )
    # to take more and more time.
    # We can't really (or I didn't find how) move all this in nice little functions, since the references for the keys in feed_dict must be
    # accessible when we need them to eval, and we can't really have "create_eval_loss" function with a different return signature for TF1 and TF2
    # So instead of hiding all of that in tfcompat package, it's here :(
    is_tf2 = package_is_at_least(tensorflow, "2.2")
    if not is_tf2:
        from keras import backend as K
        y_true_ph = K.placeholder()
        y_pred_ph = K.placeholder()
        # From here on, in tf1, loss_function is the loss applied to the placeholders and no longer a callable
        loss_function = loss_function(y_true_ph, y_pred_ph)
    # Closure over loss_function / is_tf2 (and K and the placeholders in tf1): evaluates the loss on one batch
    def compute_loss(y_true, y_pred):
        import keras
        if package_is_at_least(keras, "3"):
            # binary_crossentropy (at least) requires both arrays to be of the same shape in Keras 3
            if y_true.shape != y_pred.shape:
                y_true = y_true.reshape(y_pred.shape)
        return loss_function(y_true, y_pred).numpy() if is_tf2 else K.get_session().run(loss_function,
                                                                                        feed_dict={y_true_ph: y_true,
                                                                                                   y_pred_ph: y_pred})

    for num_batch in range(nb_steps):

        (X, y) = valid_iterator[num_batch]

        if prediction_type == doctor_constants.REGRESSION:
            valid_y_list.append(y)
            y_pred = model.predict(X)
            # The loss is computed on the raw model output, before squeezing it for the returned predictions
            summed_loss, loss_values_summed = _update_loss_counters(compute_loss(y, y_pred), summed_loss,
                                                                    loss_values_summed)
            preds_list.append(np.squeeze(y_pred, axis=1))
        elif prediction_type == doctor_constants.BINARY_CLASSIFICATION and modeling_params["keras"]["oneDimensionalOutput"]:
            valid_y_list.append(y)
            probas_one_raw = np.squeeze(model.predict(X), axis=1)
            summed_loss, loss_values_summed = _update_loss_counters(compute_loss(y, probas_one_raw), summed_loss,
                                                                    loss_values_summed)
            # Expand the single output into two columns: column 1 is the model output, column 0 its complement
            probas_raw = np.zeros((probas_one_raw.shape[0], 2))
            probas_raw[:, 1] = probas_one_raw
            probas_raw[:, 0] = 1 - probas_one_raw
            probas_list.append(probas_raw)
            # Fixed 0.5 cut-off on the model output to get the predicted class index
            preds_list.append((probas_one_raw > 0.5).astype(int))
        else:
            # i.e. for MULTICLASS and BINARY CLASSIF with 2-dimensional output
            # y is reduced to the index of its largest column, like the predictions
            valid_y_list.append(np.argmax(y, axis=1))
            probas_raw = model.predict(X)
            summed_loss, loss_values_summed = _update_loss_counters(compute_loss(y, probas_raw), summed_loss,
                                                                    loss_values_summed)
            probas_list.append(probas_raw)
            preds_list.append(np.argmax(probas_raw, axis=1))

        if on_step_end_func is not None:
            on_step_end_func(num_batch)

    # Stitch the per-batch results together
    valid_y_as_np = np.concatenate(valid_y_list)
    preds = np.concatenate(preds_list)
    if prediction_type != doctor_constants.REGRESSION:
        valid_y_as_np = valid_y_as_np.astype(int)
        probas = np.concatenate(probas_list)
    else:
        probas = None
    valid_y = pd.Series(valid_y_as_np)
    # Average loss over every accumulated loss value (one per sample or one per batch, see _update_loss_counters)
    return preds, probas, valid_y, summed_loss / loss_values_summed

def get_scored_from_y_and_pred(y, y_pred, prediction_type, modeling_params):
    """
    Turn raw targets and raw model outputs into predictions, probabilities and flat targets.

    :param y: true targets, as a 2-d array
    :param y_pred: raw model outputs, as a 2-d array
    :param prediction_type: one of the doctor_constants prediction types
    :param modeling_params: modeling params; ``modeling_params["keras"]["oneDimensionalOutput"]`` is read for
                            binary classification
    :return: tuple (preds, probas, valid_y); probas is None for regression
    """
    if prediction_type == doctor_constants.REGRESSION:
        # Targets and outputs both hold a single column: simply drop it
        targets = np.squeeze(y, axis=1)
        return np.squeeze(y_pred, axis=1), None, targets

    single_output_binary = (prediction_type == doctor_constants.BINARY_CLASSIFICATION
                            and modeling_params["keras"]["oneDimensionalOutput"])
    if single_output_binary:
        targets = np.squeeze(y, axis=1)
        positive_probas = np.squeeze(y_pred, axis=1)
        # Expand the single output into two columns: column 1 is the model output, column 0 its complement
        two_class_probas = np.zeros((positive_probas.shape[0], 2))
        two_class_probas[:, 1] = positive_probas
        two_class_probas[:, 0] = 1 - positive_probas
        # Fixed 0.5 cut-off on the model output to get the predicted class index
        return (positive_probas > 0.5).astype(int), two_class_probas, targets

    # i.e. for MULTICLASS and BINARY CLASSIF with 2-dimensional output
    targets = np.argmax(y, axis=1)
    return np.argmax(y_pred, axis=1), y_pred, targets

##############################################################
# SCORING GENERATOR FOR SCORING/EVALUATION RECIPES
##############################################################


def scored_dataset_generator(model_folder_context, input_dataset, recipe_desc, script, preparation_output_schema,
                             cond_outputs, output_y=False, output_input_df=False, fmi=None,
                             evaluation_dataset_type=CLASSICAL_EVALUATION_DATASET_TYPE,
                             filter_input_columns=True):
    """
    Generator scoring ``input_dataset`` chunk by chunk with the Keras model stored in ``model_folder_context``.

    :param model_folder_context: folder context holding the model and its params json files
    :param input_dataset: dataset to score; the preparation steps of ``script`` are appended to its own
    :param recipe_desc: recipe description (GPU config, batch size, sampling, output options, threshold...)
    :param script: preparation script; its "steps" entry is read
    :param preparation_output_schema: schema of the preparation output
    :param cond_outputs: conditional outputs; probabilities are computed if one of them uses a proba column
    :param output_y: whether the pipeline is built with the target and the target added to each result
    :param output_input_df: whether the un-normalized input chunk is added to each result
    :param fmi: forwarded to add_output_model_metadata when the recipe asks for model metadata
    :param evaluation_dataset_type: how the input dataset must be normalized (classical or API node logs)
    :param filter_input_columns: whether the "filterInputColumns" option of the recipe is taken into account
    :return: yields, for each chunk, a dict with keys "scored" and "pred_df", plus "y" and "y_notnull" when
             ``output_y`` is set and "input_df" when ``output_input_df`` is set
    :raises EmptyDatasetException: if a chunk of the input dataset is empty
    :raises ValueError: if ``evaluation_dataset_type`` is not supported
    """
    from dataiku.doctor.deep_learning import gpu

    # Load GPU Options
    gpu_config = get_gpu_config_from_recipe_desc(recipe_desc)
    if KerasGPUCapability.should_use_gpu(gpu_config):
        gpu_params = gpu_config["params"]
        gpu.load_gpu_options(gpu_params["gpuList"],
                             allow_growth=gpu_params["gpuAllowGrowth"],
                             per_process_gpu_memory_fraction=float(gpu_params["perGPUMemoryFraction"]))
    else:
        gpu.deactivate_gpu()

    batch_size = recipe_desc.get("batchSize", 100)
    sampling = recipe_desc.get("selection", {"samplingMethod":"FULL"})

    # Obtain a streamed result of the preparation: dataset steps first, then the steps of the script
    logger.info("Will do preparation, output schema: %s" % preparation_output_schema)
    preparation_steps = []
    for steps in (input_dataset.preparation_steps, script["steps"]):
        if steps:
            preparation_steps.extend(steps)
    input_dataset.set_preparation_steps(preparation_steps, preparation_output_schema,
                                        context_project_key=default_project_key())

    core_params = model_folder_context.read_json("core_params.json")
    preprocessing_params = model_folder_context.read_json("rpreprocessing_params.json")
    collector_data = model_folder_context.read_json("collector_data.json")
    resolved_params = model_folder_context.read_json("actual_params.json")["resolved"]

    prediction_type = core_params["prediction_type"]

    # Tagging special features to take them into account only in special_preproc_handler/special_pipeline
    per_feature = preprocessing_params["per_feature"]
    tag_special_features(per_feature)

    preproc_handler = PredictionPreprocessingHandler.build(core_params,
                                                           preprocessing_params,
                                                           model_folder_context)
    preproc_handler.collector_data = collector_data
    pipeline = preproc_handler.build_preprocessing_pipeline(with_target=output_y)
    target_map = preproc_handler.target_map

    logger.info("Loading model")
    model = keras_model_io_utils.load_model(model_folder_context)

    logger.info("Start output generator")

    names, dtypes, parse_date_columns = get_dataframe_dtypes_info(
        preparation_output_schema, preprocessing_params["per_feature"],
        prediction_type=prediction_type
    )

    # normalize=False because the normalization is done later and doesn't always use the normalize_dataframe method
    chunks = dataframe_iterator(
        input_dataset, names, dtypes, parse_date_columns, preprocessing_params["per_feature"],
        batch_size=batch_size, sampling=sampling, normalize=False,
    )
    for chunk_df, raw_chunk_df in chunks:
        if chunk_df.empty:
            raise EmptyDatasetException("The input dataset can not be empty. Check the input or the recipe sampling configuration.")

        # Normalization depends on where the data to score comes from
        if evaluation_dataset_type in (API_NODE_EVALUATION_DATASET_TYPE, CLOUD_API_NODE_EVALUATION_DATASET_TYPE):
            chunk_df = normalize_api_node_logs_dataset(chunk_df, preprocessing_params['per_feature'], evaluation_dataset_type)
        elif evaluation_dataset_type == CLASSICAL_EVALUATION_DATASET_TYPE:
            normalize_dataframe(chunk_df, preprocessing_params['per_feature'])
        else:
            raise ValueError("Dataset handled as %s is not supported for keras modes" % evaluation_dataset_type)

        for column_name in chunk_df:
            logger.info("NORMALIZED: %s -> %s" % (column_name, chunk_df[column_name].dtype))

        logger.info("Processing chunk")

        transformed = pipeline.process(chunk_df)
        train_features = transformed["TRAIN"]

        inputs_dict = split_train_per_input(train_features, per_feature, pipeline.generated_features_mapping)

        if prediction_type in [doctor_constants.MULTICLASS, doctor_constants.BINARY_CLASSIFICATION]:

            # class index -> original label, and the labels ordered by class index
            inv_map = {int(class_id): label for label, class_id in target_map.items()}
            classes = [inv_map[class_id] for class_id in sorted(inv_map)]

            if prediction_type == doctor_constants.MULTICLASS:
                probas_raw = model.predict(inputs_dict)
                preds = np.argmax(probas_raw, axis=1)
            else:
                # Binary classification: the prediction comes from thresholding the proba of class 1
                if resolved_params["keras"]["oneDimensionalOutput"]:
                    probas_one = np.squeeze(model.predict(inputs_dict), axis=1)
                    probas_raw = np.zeros((probas_one.shape[0], 2))
                    probas_raw[:, 1] = probas_one
                    probas_raw[:, 0] = 1 - probas_one
                else:
                    probas_raw = model.predict(inputs_dict)
                    probas_one = probas_raw[:, 1]

                threshold = recipe_desc["forcedClassifierThreshold"]
                preds = (probas_one > threshold).astype(int)

            (nb_rows, nb_present_classes) = probas_raw.shape
            logger.info("Probas raw shape %s/%s target_map=%s", nb_rows, nb_present_classes, len(target_map))

            # Replace class indices by the original labels
            preds_remapped = np.zeros(preds.shape, dtype="object")
            for class_id, label in inv_map.items():
                preds_remapped[preds == class_id] = label
            pred_df = pd.DataFrame({"prediction": preds_remapped})
            pred_df.index = train_features.index

            proba_cols = [u"proba_{}".format(safe_unicode_str(c)) for c in classes]
            # For Binary Classification: Must compute probas if there are conditional outputs that use them
            # Will be deleted afterwards (if outputProbabilities is False)
            # in binary_classif_scoring_add_percentile_and_cond_outputs
            if cond_outputs:
                proba_cond_outputs = [co for co in cond_outputs if co["input"] in proba_cols]
            else:
                proba_cond_outputs = []
            if recipe_desc["outputProbabilities"] or len(proba_cond_outputs) > 0:
                proba_df = pd.DataFrame(probas_raw, columns=proba_cols)
                proba_df.index = train_features.index
                pred_df = pd.concat([proba_df, pred_df], axis=1)

            if prediction_type == doctor_constants.BINARY_CLASSIFICATION:
                pred_df = binary_classif_scoring_add_percentile_and_cond_outputs(pred_df,
                                                                                 recipe_desc,
                                                                                 model_folder_context,
                                                                                 cond_outputs,
                                                                                 target_map)

        elif prediction_type == doctor_constants.REGRESSION:
            preds = model.predict(inputs_dict)
            pred_df = pd.DataFrame({"prediction": np.squeeze(preds, axis=1)})
            pred_df.index = train_features.index

        logger.info("Done predicting it")
        # Input columns to keep next to the predictions; columns produced by the scoring always win
        if filter_input_columns and recipe_desc.get("filterInputColumns", False):
            candidate_columns = recipe_desc["keptInputColumns"]
        else:
            candidate_columns = raw_chunk_df.columns
        clean_kept_columns = [c for c in candidate_columns if c not in pred_df.columns]

        if recipe_desc.get("outputModelMetadata", False):
            add_output_model_metadata(pred_df, fmi)

        res = {
            "scored": pd.concat([raw_chunk_df[clean_kept_columns], pred_df], axis=1),
            "pred_df": pred_df
        }

        if output_y:
            target = transformed["target"]
            res["y"] = target.reindex(raw_chunk_df.index)  # for use in computing error columns
            res["y_notnull"] = transformed["target"]

        if output_input_df:
            res["input_df"] = raw_chunk_df

        yield res
