import threading

from dataiku.core import doctor_constants
from dataiku.doctor.utils import unix_time_millis


##############################################################
# KERAS CALLBACKS
##############################################################

# Callback classes are defined inside factory functions so that the Keras import only runs when it is actually
# needed: the keras_utils script is imported by scripts that do not always run in Keras-aware environments.


def _tensorboard_callback(tensorboard_folder_path):
    """Return a Keras ``TensorBoard`` callback that writes its logs to *tensorboard_folder_path*.

    The Keras import is deferred to call time so that importing this module does not
    require a Keras-aware environment.
    """
    from keras.callbacks import TensorBoard

    tensorboard = TensorBoard(log_dir=tensorboard_folder_path)
    return tensorboard


def get_base_callbacks(run_folder_context, modeling_params, validation_sequence,
                       prediction_type, target_map, save_model, tensorboard_folder_path, **kwargs):
    """Assemble the list of callbacks attached to every deep-learning training run.

    The returned list contains, in this order:

    1. the callback produced by ``tfcompat.compute_perf_and_save_best_model_callback``
       (it receives the validation sequence, prediction type, target map, ``save_model``
       and any extra ``kwargs``);
    2. a callback that stops training when a stop-request file appears in the run folder;
    3. a Keras ``TensorBoard`` callback logging into *tensorboard_folder_path*;
    4. a callback recording how many epochs were actually completed.

    :return: list of the four callbacks above
    """
    from dataiku.doctor.deep_learning.tfcompat import compute_perf_and_save_best_model_callback

    # Built one by one, in the same order as the returned list.
    perf_callback = compute_perf_and_save_best_model_callback(
        run_folder_context, modeling_params, validation_sequence,
        prediction_type, target_map, save_model, **kwargs)
    interrupt_callback = _interrupt_callback(run_folder_context)
    tensorboard_callback = _tensorboard_callback(tensorboard_folder_path)
    monitor_callback = _monitor_epochs_callback(run_folder_context, modeling_params)

    return [perf_callback, interrupt_callback, tensorboard_callback, monitor_callback]


def _interrupt_callback(run_folder_context):
    """Build a callback that stops training when a stop-request file shows up.

    At the end of every epoch the callback asks *run_folder_context* whether
    ``doctor_constants.STOP_SEARCH_FILENAME`` exists (with ``allow_cached=False``).
    If it does, it sets ``model.stop_training = True``, which is the Keras mechanism
    for ending ``fit`` early.

    :param run_folder_context: folder context of the current run; must expose ``isfile``
    :return: an ``InterruptCallback`` instance
    """
    from dataiku.doctor.deep_learning.tfcompat import Callback

    class InterruptCallback(Callback):

        def __init__(self, folder_context):
            # Initialise the Keras base class: it sets up state (e.g. ``model``) that
            # Keras internals may read directly before/without calling ``set_model``.
            super(InterruptCallback, self).__init__()
            self._folder_context = folder_context

        def on_epoch_end(self, epoch, logs=None):
            # allow_cached=False: the stop file may be created at any time by another
            # process, so a stale cached answer must not be used.
            if self._folder_context.isfile(doctor_constants.STOP_SEARCH_FILENAME, allow_cached=False):
                self.model.stop_training = True

    return InterruptCallback(run_folder_context)


def _monitor_epochs_callback(run_folder_context, modeling_params):
    """Build a callback that records how many epochs training actually completed.

    Training can stop before the configured number of epochs (for instance when
    ``model.stop_training`` is set by another callback). When training ends, this
    callback:

    * overwrites ``modeling_params['keras']['epochs']`` in place with the number of
      completed epochs (0 if no epoch finished);
    * stores the same count under ``"nbEpochs"`` in the training-info JSON managed by
      ``DLModelTrainingInfoHandler``, forcing the write so it bypasses the throttling.

    :param run_folder_context: folder context holding the training-info file
    :param modeling_params: modeling parameters dict; mutated when training ends
    :return: a ``MonitorEpochsCallback`` instance
    """
    from dataiku.doctor.deep_learning.tfcompat import Callback

    class MonitorEpochsCallback(Callback):

        def __init__(self, folder_context, modeling_params):
            # Initialise the Keras base class so its internal attributes exist.
            super(MonitorEpochsCallback, self).__init__()
            self.modeling_params = modeling_params
            # -1 means "no epoch finished yet", so the count below becomes 0.
            self.last_finished_epoch = -1
            self.train_info_handler = DLModelTrainingInfoHandler(folder_context)

        def on_epoch_end(self, epoch, logs=None):
            self.last_finished_epoch = epoch

        def on_train_end(self, logs=None):
            # Keras passes 0-based epoch indices, hence the +1 to get a count.
            nb_epochs = self.last_finished_epoch + 1
            self.modeling_params['keras']['epochs'] = nb_epochs
            model_info = self.train_info_handler.get_info()
            model_info["nbEpochs"] = nb_epochs
            # Final value: write it even if the throttling delay has not elapsed.
            self.train_info_handler.update_info(model_info, force=True)

    return MonitorEpochsCallback(run_folder_context, modeling_params)


class DLModelTrainingInfoHandler:
    """
    Handles the update of model information file on disk, to inform the front-end on the progress of the training

    The information is stored as ``<info_filename>.json`` in ``folder_context``. Non-forced
    updates are throttled: a write only happens if more than ``delay`` seconds elapsed since
    the last throttled write.
    """

    def __init__(self, folder_context, info_filename="keras_model_training_info", delay=2):
        """
        :param folder_context: object exposing ``write_json(path, obj)`` and ``read_json(path)``
        :param info_filename: base name of the info file (".json" is appended)
        :param delay: minimum number of seconds between two non-forced writes
        """
        self.folder_context = folder_context
        self.delay = delay
        self.info_filename = info_filename
        # this should not be necessary, but this object is passed to a 3rd party library, so we are extra careful
        self._lock = threading.Lock()

        # Timestamp (ms) of the last throttled write; starts at creation time.
        self.last_updated = unix_time_millis()

    def _consume_update_slot(self):
        """Unlocked throttle check: if the delay elapsed, reset the timer and return True.

        Callers must hold ``self._lock`` so the read-then-write of ``last_updated`` is atomic.
        """
        curr_time = unix_time_millis()
        if (curr_time - self.last_updated) > self.delay * 1000:
            self.last_updated = curr_time
            return True

        return False

    def should_update(self):
        """Return True (and restart the delay) if more than ``delay`` seconds passed since the last update."""
        with self._lock:
            return self._consume_update_slot()

    def update_info(self, new_info, force=False):
        """Write *new_info* to the info file, unless throttled.

        :param new_info: JSON-serialisable object to write
        :param force: write regardless of the throttling delay (does not restart the delay)
        """
        # The throttle check and the write happen under the same lock, so two threads
        # cannot both pass the check for the same time slot.
        with self._lock:
            if force or self._consume_update_slot():
                self.folder_context.write_json("{}.json".format(self.info_filename), new_info)

    def get_info(self):
        """Read and return the current content of the info file."""
        # Same lock as the writer, so a read never overlaps a write from this object.
        with self._lock:
            return self.folder_context.read_json("{}.json".format(self.info_filename))
