import logging
import sys

from dataiku.base.utils import package_is_at_least
from dataiku.doctor.deep_learning import custom_objects_handler

logger = logging.getLogger(__name__)

# Candidate file names for the serialized Keras model inside a model folder.
KERAS_MODEL_FILENAME_H5 = "keras_model.h5"
KERAS_MODEL_FILENAME_KERAS = "keras_model.keras"

import keras
# File name used for both saving and (primary) loading: Keras 3 and up use the
# ".keras" name, earlier Keras versions keep the ".h5" name.
KERAS_MODEL_FILENAME = (
    KERAS_MODEL_FILENAME_KERAS if package_is_at_least(keras, "3") else KERAS_MODEL_FILENAME_H5
)

def save_model(model, model_folder_context):
    """Save a Keras model, and the currently registered custom objects, into a model folder.

    :param model: the Keras model to persist; its ``save`` method is called with ``overwrite=True``.
    :param model_folder_context: folder context providing ``get_file_path_to_write``; the
        custom objects handler also writes into it.
    """
    # Persist custom objects first so they can be restored before the model is reloaded.
    custom_objects_handler.save_current_custom_objects(model_folder_context)
    writable_file = model_folder_context.get_file_path_to_write(KERAS_MODEL_FILENAME)
    with writable_file as model_path:
        target_location = safe_model_location(model_path)
        model.save(target_location, overwrite=True)


def load_model(model_folder_context):
    """Load a Keras model previously stored in a model folder.

    The custom objects saved alongside the model are restored first, then
    ``KERAS_MODEL_FILENAME`` is loaded. On Keras 3 and up (where that name is the
    ``.keras`` one), a ``ValueError`` triggers a fallback to the legacy ``.h5`` file,
    so models trained before Keras 3 can still be loaded.

    :param model_folder_context: folder context providing ``get_file_path_to_read``.
    :return: the loaded Keras model.
    :raises ValueError: if loading fails and no legacy fallback applies (the primary
        file is already the ``.h5`` one). Errors raised while loading the legacy
        ``.h5`` file propagate as well.
    """
    from dataiku.doctor.deep_learning.tfcompat import keras_load_model
    custom_objects = custom_objects_handler.load_custom_objects(model_folder_context)
    try:
        with model_folder_context.get_file_path_to_read(KERAS_MODEL_FILENAME) as model_path:
            return keras_load_model(safe_model_location(model_path), custom_objects=custom_objects)
    except ValueError:
        if KERAS_MODEL_FILENAME != KERAS_MODEL_FILENAME_KERAS:
            # The primary file already is the legacy .h5 one: there is nothing to fall
            # back to, so surface the error instead of silently returning None.
            raise

        # Fallback to the legacy model filename, useful for models trained before Keras 3
        # then loaded on Keras 3 and up.
        logger.warning("failed to load a .keras model, loading a .h5 model instead")

        # This is required to let Keras find the functions even if the name didn't change:
        # https://github.com/keras-team/keras/issues/19330#issuecomment-2163993015
        # Work on a copy so the handler's dict is not mutated, and use setdefault so that
        # user-provided custom objects with the same names take precedence over the aliases.
        legacy_custom_objects = dict(custom_objects) if custom_objects else {}
        # https://github.com/keras-team/tf-keras/blob/v2.19.0/tf_keras/metrics/__init__.py#L123
        for alias in ("acc", "bce", "mse", "mae", "mape", "msle", "log_cosh", "cosine_proximity"):
            legacy_custom_objects.setdefault(alias, alias)

        with model_folder_context.get_file_path_to_read(KERAS_MODEL_FILENAME_H5) as model_path:
            return keras_load_model(safe_model_location(model_path), custom_objects=legacy_custom_objects)


def safe_model_location(model_location):
    """Return a model location that Keras ``save``/``load_model`` accept.

    From Keras 2.2.3, it is mandatory to pass a 'str' to save and load_model,
    otherwise it fails. On Python 2 a ``unicode`` location is therefore encoded
    to a UTF-8 byte string; every other value is returned unchanged.
    """
    # Short-circuit keeps the (Python 2 only) ``unicode`` name from being evaluated on Python 3.
    is_py2_unicode = sys.version_info < (3, 0) and isinstance(model_location, unicode)  # noqa: F821
    if not is_py2_unicode:
        return model_location
    return model_location.encode("utf-8")
