import json
from os import path as osp

import numpy as np
import pandas as pd
import tensorflow as tf

from dataiku.core import doctor_constants
from dataiku.doctor.deep_learning import keras_model_io_utils
from dataiku.doctor.deep_learning.keras_callbacks import DLModelTrainingInfoHandler
from dataiku.doctor.deep_learning.keras_support import build_scored_validation_data, get_scored_from_y_and_pred
from dataiku.doctor.deep_learning.shared_variables import set_variable
from dataiku.doctor.deep_learning.tfcompat._tf_imports_compat import Callback
from dataiku.doctor.prediction.decisions_and_cuts import DecisionsAndCuts
from dataiku.doctor.prediction.classification_scoring import compute_optimized_threshold
from dataiku.doctor.prediction.common import get_grid_scorers
from dataiku.doctor.prediction.common import greater_is_better
from dataiku.doctor.prediction.custom_scoring import get_custom_evaluation_metric
from dataiku.doctor.prediction.metric import METRICS_NAMES
from dataiku.doctor.utils import unix_time_millis


def _compute_perf_and_save_best_model_callback_tf2(run_folder_context, modeling_params, validation_sequence,
                                                   prediction_type, target_map, save_model):
    """
    Build the TF2 flavour of the per-epoch scoring / best-model-saving callback.

    Every argument is forwarded unchanged to ComputePerfAndSaveBestModelCallBack.

    :return: a ComputePerfAndSaveBestModelCallBack instance
    """
    callback = ComputePerfAndSaveBestModelCallBack(run_folder_context=run_folder_context,
                                                   modeling_params=modeling_params,
                                                   validation_sequence=validation_sequence,
                                                   prediction_type=prediction_type,
                                                   target_map=target_map,
                                                   save_model=save_model)
    return callback


def _compute_perf_and_save_best_model_callback_tf1(run_folder_context, modeling_params, validation_sequence,
                                                   prediction_type, target_map, save_model,
                                                   use_multi_gpus=False, base_model=None):
    """
    Build the TF1 flavour of the per-epoch scoring / best-model-saving callback.

    :param use_multi_gpus: forwarded as a keyword to ComputePerfAndSaveBestModelCallBackTF1
    :param base_model: forwarded as a keyword to ComputePerfAndSaveBestModelCallBackTF1
    :return: a ComputePerfAndSaveBestModelCallBackTF1 instance
    """
    # The TF1 callback takes the arguments shared with the base callback positionally,
    # and the multi-GPU options by keyword.
    shared_args = (run_folder_context, modeling_params, validation_sequence,
                   prediction_type, target_map, save_model)
    return ComputePerfAndSaveBestModelCallBackTF1(*shared_args,
                                                  use_multi_gpus=use_multi_gpus,
                                                  base_model=base_model)


class ComputePerfAndSaveBestModelCallBack(Callback):
    """
    Keras callback that scores the model on the validation sequence after every epoch with all the
    DSS metrics, tracks which epoch produced the best model and optionally saves that model.

    It also keeps ``model_training_info`` (epochs, step counters, architecture...) up to date through a
    ``DLModelTrainingInfoHandler`` so that the training / epoch progress graphs can be filled.
    """

    def __init__(self, run_folder_context, modeling_params, validation_sequence, prediction_type,
                 target_map, save_model):
        """

        :param run_folder_context: Disk location where model data is saved
        :param modeling_params: User set modeling params. Mutated by this callback: its ``keras`` sub-dict
            receives ``epochs``, ``completedEpochs`` and ``keptModelEpoch``
        :param validation_sequence: Test data to evaluate the model after each epoch
        :param prediction_type: REGRESSION, MULTICLASS or BINARY_CLASSIFICATION
        :param target_map: Remapping of class names between the classifier output and the dataset values
        :param save_model: Whether to save the best found model across epochs
        """
        # Run the Keras base-class initializer so that whatever state it sets up exists before Keras
        # attaches the model and the training params to this callback.
        super(ComputePerfAndSaveBestModelCallBack, self).__init__()

        self.run_folder_context = run_folder_context
        self.modeling_params = modeling_params
        self.validation_sequence = validation_sequence
        self.prediction_type = prediction_type
        self.target_map = target_map

        self.epoch_start_time = None
        # Scorers are built through _scorer_func, which multiplies by -1 the metrics whose
        # greater_is_better flag is False
        self.all_scorers = get_grid_scorers(self.modeling_params, self.prediction_type, self.target_map,
                                            custom_make_scorer=self._scorer_func)
        self.model_best_score = None
        self.kept_model_epoch = -1
        self.save_model = save_model

        # Share the name of the metric used to optimize the model.
        # The user can then retrieve it, for example to write their own callback
        self.evaluation_metric = self.modeling_params['metrics']['evaluationMetric']
        set_variable("DKU_MODEL_METRIC",
                     "Test {}".format(METRICS_NAMES[self.evaluation_metric]))

        # Only a custom metric carries its own "greater is better" flag
        if self.evaluation_metric == "CUSTOM":
            custom_greater_is_better = get_custom_evaluation_metric(
                self.modeling_params["metrics"]
            )["greaterIsBetter"]
        else:
            custom_greater_is_better = True
        set_variable("DKU_MODEL_METRIC_GREATER_IS_BETTER",
                     greater_is_better(self.evaluation_metric, custom_greater_is_better))

        # Initialize model info
        self.model_training_info = {
            "startedAt": unix_time_millis(),
            "epochs": [],
            'metric': self.evaluation_metric,
        }

        self.resolved_params = {
            "completedEpochs": 0,
            "keptModelEpoch": -1
        }

        self.train_info_handler = DLModelTrainingInfoHandler(self.run_folder_context)

    # Reuse logic of Grid Search scorer of Regular Python backend in order to leverage the list of
    # metrics available from the front-end
    def _scorer_func(self, score_func, needs_proba=False, greater_is_better=True, **kwargs):
        """
        Scorer factory passed to get_grid_scorers as ``custom_make_scorer``.

        :param score_func: metric function, called as ``score_func(y_true, y_pred_or_probas, **kwargs)``
        :param needs_proba: whether score_func must receive the probabilities instead of the predictions
        :param greater_is_better: when False, the metric value is negated so that higher is always better
        :param kwargs: extra keyword arguments forwarded to score_func
        :return: a function ``score(y, y_pred, probas)`` returning the signed metric value
        """
        sign = 1 if greater_is_better else -1

        def score(y, y_pred, probas):
            if needs_proba:
                return sign * score_func(y, probas, **kwargs)
            else:
                return sign * score_func(y, y_pred, **kwargs)

        return score

    def _optimize_threshold(self, valid_y, probas, preds):
        """
        For binary classification with ``autoOptimizeThreshold`` enabled, recompute the predictions by
        thresholding the second column of ``probas`` with the cut optimized on (valid_y, probas).

        :return: the new predictions (a pandas Series of ints) or ``preds`` unchanged otherwise
        """
        if self.prediction_type == doctor_constants.BINARY_CLASSIFICATION:

            optimize_threshold = self.modeling_params["autoOptimizeThreshold"]

            if optimize_threshold:
                decisions_and_cuts = DecisionsAndCuts.from_probas(probas, self.target_map)
                best_cut = compute_optimized_threshold(valid_y, decisions_and_cuts, self.modeling_params['metrics'])
                preds = (pd.Series(probas[:, 1]) > best_cut).astype(int)

        return preds

    def _compute_train_scores(self):
        """
        Metrics on the training data. Not available by default (see on_epoch_end), subclasses may override.

        :return: dict scorer name -> signed score, empty here
        """
        return {}

    def _compute_test_scores(self):
        """
        Score the model on the whole validation sequence with every scorer of ``all_scorers``.

        :return: tuple (validation loss, dict scorer name -> signed score)
        """

        # While scoring, update number of steps done so far to fill epoch progress graph
        def on_step_end(step):
            self.model_training_info["currentNumStepsScoring"] += 1
            if step % 10 == 0 or step == (self.model_training_info["nbStepsScoringPerEpoch"] - 1):
                self._update_model_training_info()

        # For the validation set we enforce to score on all the data, because we cannot retrieve potential
        # validation_steps
        y_validation, probas_validation, true_y_validation, validation_loss = build_scored_validation_data(
            self.model,
            self.prediction_type,
            self.modeling_params,
            self.validation_sequence,
            nb_steps=None,
            on_step_end_func=on_step_end)
        y_validation = self._optimize_threshold(true_y_validation, probas_validation, y_validation)
        return validation_loss, {
            scorer_name: np.float64(scoring_func(true_y_validation, y_validation, probas_validation))
            for scorer_name, scoring_func in self.all_scorers.items()}

    def _update_epoch_graph(self, train_score, test_score, train_loss, test_loss, epoch, kept_model_epoch):
        """Append the data point of the finished epoch and force-publish the training info."""
        epoch_finish_time = unix_time_millis()

        new_point = {
            'time': epoch_finish_time - self.epoch_start_time,
            'index': epoch + 1,
            'trainScore': train_score,
            'testScore': test_score,
            'trainLoss': train_loss,
            'testLoss': test_loss,
            'epoch': epoch
        }
        self.model_training_info['epochs'].append(new_point)
        self.model_training_info['keptModelEpoch'] = kept_model_epoch
        self._update_model_training_info(force=True)

    def _update_model_training_info(self, force=False):
        """Publish ``model_training_info`` through the handler; ``force`` is forwarded to it."""
        self.train_info_handler.update_info(self.model_training_info, force=force)

    def _save_model(self):
        """Save the model being trained into the run folder."""
        keras_model_io_utils.save_model(self.model, self.run_folder_context)

    def _get_model_architecture(self):
        """:return: the JSON-encoded architecture of the model being trained."""
        return json.dumps(self.model.to_json())

    def on_train_begin(self, logs=None):
        """Record the training params provided by Keras (epochs, steps per epoch) and the architecture."""
        self.modeling_params['keras']['epochs'] = self.params["epochs"]
        self.model_training_info["nbStepsTrainingPerEpoch"] = self.params["steps"]
        self.model_training_info["nbStepsScoringPerEpoch"] = len(self.validation_sequence)
        self.model_training_info["architecture"] = self._get_model_architecture()
        self.model_training_info["nbEpochs"] = self.params["epochs"]

    def on_train_end(self, logs=None):
        """Report the number of completed epochs and the kept epoch back into the modeling params."""
        self.modeling_params["keras"]["completedEpochs"] = len(self.model_training_info["epochs"])
        # model_training_info["keptModelEpoch"] only exists once an epoch has completed, so reading it
        # raised a KeyError when training stopped before the end of the first epoch. The attribute holds
        # the same value once an epoch has finished, and -1 before that.
        self.modeling_params["keras"]["keptModelEpoch"] = self.kept_model_epoch

    def on_epoch_begin(self, epoch, logs=None):
        """Reset the per-epoch step counters and start timing the epoch."""
        self.epoch_start_time = unix_time_millis()
        self.model_training_info["currentNumStepsTraining"] = 0
        self.model_training_info["currentNumStepsScoring"] = 0
        self.model_training_info["currentEpoch"] = epoch
        self._update_model_training_info(force=True)

    def on_batch_end(self, batch, logs=None):
        """Count one more training step (publication is left to the handler, without forcing)."""
        self.model_training_info["currentNumStepsTraining"] += 1
        self._update_model_training_info()

    def on_epoch_end(self, epoch, logs=None):
        """
        Loss computation note:
        We want to get values for all dataiku metrics and model loss on the validation set,
        and get what's available for the training set : on tf1 this will be all metrics and loss, on tf2 this will be only loss
        (since we can't hook into the training loop in the same way to compute metrics)

        This is the goal of the history object which is the returned value of model.fit(), so why not use that ?
        - We want to compute every metric defined through all_scorers independently of the architecture/training user code snippets
        - The list of metrics in history is the list provided to the call to model.compile(), which is in the snippets :-/
        So to use history, we'd have to either force those snippets to have a certain content, or modify them forcefully in the backend,
        both options being user hostile and bug prone

        We prefer not feeding a validation_sequence to the fit() call, and manually do a pass to predict() the whole sequence and at that point compute
        the metrics and the model loss based on predicted values. This is done in _compute_test_scores() which in turn calls build_scored_validation_data()
        """
        if logs is None:
            logs = {}
        # The 4 data points to update the session graphs after each epoch
        train_score = train_loss = test_score = test_loss = None
        # Tracking of the best model to know which is the best epoch (higher is always better)
        model_score = None
        # if test data is present:
        if len(self.validation_sequence) > 0:
            # First compute metric
            test_loss, test_scores = self._compute_test_scores()
            if test_scores:
                logs.update({"Test {}".format(METRICS_NAMES[k]): v for k, v in test_scores.items()})
                model_score = test_score = test_scores[self.evaluation_metric]
            logs["Test loss"] = test_loss
            # Same None guard as for the train loss below: negating None would raise a TypeError
            if model_score is None and test_loss is not None:
                model_score = - test_loss

        train_scores = self._compute_train_scores()
        if train_scores:
            logs.update({"Train {}".format(METRICS_NAMES[k]): v for k, v in train_scores.items()})
            train_score = train_scores[self.evaluation_metric]
        train_loss = logs.get("loss")

        if model_score is None and train_loss is not None:
            model_score = - train_loss

        # Order of preference for evaluating models :
        # dss evaluation metric value preferred to
        # validation sequence model loss preferred to
        # train sequence model loss preferred to
        # every epoch is better than the previous
        better_model = (model_score is None
                        or self.model_best_score is None
                        or model_score > self.model_best_score)

        if better_model:
            self.model_best_score = model_score
            self.kept_model_epoch = epoch
            if self.save_model:
                self._save_model()
        self._update_epoch_graph(train_score, test_score, train_loss, test_loss, epoch, self.kept_model_epoch)


class ComputePerfAndSaveBestModelCallBackTF1(ComputePerfAndSaveBestModelCallBack):
    """
    TF1 variant of ComputePerfAndSaveBestModelCallBack that also computes the DSS metrics on the
    training data (accumulated batch by batch), and that saves / describes ``base_model`` instead
    of the trained model when ``use_multi_gpus`` is set.
    """

    def __init__(self, *args, **kwargs):
        """
        :param args: positional arguments of ComputePerfAndSaveBestModelCallBack
        :param kwargs: keyword arguments of ComputePerfAndSaveBestModelCallBack, plus:
            - ``use_multi_gpus`` (default False): when True, ``base_model`` is the model that is saved and
              whose architecture is reported, instead of the model being trained (presumably a multi-GPU
              replica of ``base_model`` - confirm with the caller)
            - ``base_model`` (default None): see ``use_multi_gpus``
        """
        self.use_multi_gpus = kwargs.pop('use_multi_gpus', False)
        self.base_model = kwargs.pop('base_model', None)
        # Forward the remaining keyword arguments as well: dropping them made a caller passing e.g.
        # ``save_model=...`` by keyword fail with a misleading missing-argument error.
        super(ComputePerfAndSaveBestModelCallBackTF1, self).__init__(*args, **kwargs)

        # We want to compute the metrics on the training data as well. To do it in a Keras way
        # we retrieve, after each batch, the value of y and y_pred for this batch for the model at this
        # stage of the training, accumulate them and then compute the score on all the values retrieved during the
        # epoch. This means that it does not correspond exactly to the score on the training
        # data with a fixed model at the end of an epoch, but to the score of an evolving model.
        # Those values are stored in TensorFlow Variables in the model so we need to tell TensorFlow that we want
        # to retrieve them

        # Variables to accumulate values of y and y_pred after each batch
        self.y_list = None
        self.y_pred_list = None

        # TensorFlow Variables that are placeholders for values of y and y_pred
        self.var_y = tf.Variable(0., validate_shape=False)
        self.var_y_pred = tf.Variable(0., validate_shape=False)

    def _compute_train_scores(self):
        """
        Compute every DSS metric on the (y, y_pred) values accumulated over the batches of the epoch.

        :return: dict scorer name -> signed score, empty when no batch has been recorded
        """
        if not self.y_list or not self.y_pred_list:
            # np.concatenate raises on an empty list: without any recorded batch there is no train metric,
            # which the caller handles exactly like the base implementation's empty result
            return {}

        y_train = np.concatenate(self.y_list)
        y_pred_train = np.concatenate(self.y_pred_list)
        preds_train, probas_train, valid_y_train_np = get_scored_from_y_and_pred(y_train, y_pred_train,
                                                                                 self.prediction_type,
                                                                                 self.modeling_params)
        true_y_train = pd.Series(valid_y_train_np)
        preds_train = self._optimize_threshold(true_y_train, probas_train, preds_train)

        return {scorer_name: np.float64(scoring_func(true_y_train, preds_train, probas_train))
                for scorer_name, scoring_func in self.all_scorers.items()}

    def _get_persisted_model(self):
        """:return: the model to save / describe: ``base_model`` if ``use_multi_gpus``, else the trained model."""
        return self.base_model if self.use_multi_gpus else self.model

    def _save_model(self):
        """Save the persisted model (see _get_persisted_model) into the run folder."""
        keras_model_io_utils.save_model(self._get_persisted_model(), self.run_folder_context)

    def _get_model_architecture(self):
        """:return: the JSON-encoded architecture of the persisted model (see _get_persisted_model)."""
        return json.dumps(self._get_persisted_model().to_json())

    def on_train_begin(self, logs=None):
        """Run the base bookkeeping, then hook the y / y_pred capture into the Keras train function."""
        super(ComputePerfAndSaveBestModelCallBackTF1, self).on_train_begin(logs)
        # Telling TensorFlow which variables to retrieve after training each batch.
        # This needs to be done after compilation of the model and after the call to
        # the function _make_train_function which actually builds `model.train_function`
        # which is called at the beginning of fit
        fetches = [tf.assign(self.var_y, self.model.targets[0], validate_shape=False),
                   tf.assign(self.var_y_pred, self.model.outputs[0], validate_shape=False)]
        self.model.train_function.fetches = fetches

    def on_epoch_begin(self, epoch, logs=None):
        """Run the base bookkeeping and reset the y / y_pred accumulators."""
        super(ComputePerfAndSaveBestModelCallBackTF1, self).on_epoch_begin(epoch, logs)
        # Reinitialize the accumulators of y and y_pred at the beginning of each epoch.
        self.y_list = []
        self.y_pred_list = []

    def on_batch_end(self, batch, logs=None):
        """Run the base bookkeeping and accumulate the y / y_pred values of this batch."""
        super(ComputePerfAndSaveBestModelCallBackTF1, self).on_batch_end(batch, logs)
        # Imported lazily: standalone keras is only needed on this TF1 code path
        import keras.backend as K
        # Evaluate the variables and save them into the accumulators.
        self.y_list.append(K.eval(self.var_y))
        self.y_pred_list.append(K.eval(self.var_y_pred))
