import dataiku
from dataiku.base import remoterun
from dataiku.base.folder_context import FmiReadonlyFolderContexts
from dataiku.base.folder_context import build_saved_model_folder_context
from dataiku.core import doctor_constants
from dataiku.core.intercom import backend_json_call
from dataiku.doctor.deep_learning import keras_model_io_utils


def get_keras_model_from_trained_model(session_id=None, analysis_id=None, mltask_id=None):
    """Load the Keras model trained by a session of a visual ML task.

    :param session_id: session to read (e.g. "s2"); None selects the latest session
    :param analysis_id: analysis id; None falls back to the current-analysis environment variable
    :param mltask_id: ML task id; None falls back to the current-mltask environment variable
    :return: whatever ``keras_model_io_utils.load_model`` returns for the model folder
    :raises ValueError: if the model file is absent from the resolved model folder
    """
    folder_ctx = _get_keras_model_folder_context_from_trained_model(session_id, analysis_id, mltask_id)
    model_file = keras_model_io_utils.KERAS_MODEL_FILENAME

    if folder_ctx.isfile(model_file):
        return keras_model_io_utils.load_model(folder_ctx)

    raise ValueError("No model found for this mltask. Did it run without errors ?")


def get_keras_model_location_from_trained_model(session_id=None, analysis_id=None, mltask_id=None):
    """Return the location of the Keras model file trained by a visual ML task session.

    :param session_id: session to read (e.g. "s2"); None selects the latest session
    :param analysis_id: analysis id; None falls back to the current-analysis environment variable
    :param mltask_id: ML task id; None falls back to the current-mltask environment variable
    :return: the value produced by ``keras_model_io_utils.safe_model_location`` for the
        readable path of the model file
    :raises ValueError: if the model file is absent from the resolved model folder
    """
    folder_ctx = _get_keras_model_folder_context_from_trained_model(session_id, analysis_id, mltask_id)
    model_file = keras_model_io_utils.KERAS_MODEL_FILENAME

    if not folder_ctx.isfile(model_file):
        raise ValueError("No model found for this mltask. Did it run without errors ?")

    # The location is computed while the read context is still open
    with folder_ctx.get_file_path_to_read(model_file) as local_path:
        return keras_model_io_utils.safe_model_location(local_path)


def get_keras_model_from_saved_model(saved_model_id):
    """Load the Keras model of the active version of a saved model in the current project.

    :param saved_model_id: identifier of the saved model
    :return: whatever ``keras_model_io_utils.load_model`` returns for the model folder
    :raises ValueError: if the model file is absent from the saved model folder
    """
    folder_ctx = _get_keras_model_folder_context_from_saved_model(saved_model_id)
    model_file = keras_model_io_utils.KERAS_MODEL_FILENAME

    if folder_ctx.isfile(model_file):
        return keras_model_io_utils.load_model(folder_ctx)

    raise ValueError("No model found for this saved model.")


def get_keras_model_location_from_saved_model(saved_model_id):
    """Return the location of the Keras model file of a saved model's active version.

    :param saved_model_id: identifier of the saved model in the current project
    :return: the value produced by ``keras_model_io_utils.safe_model_location`` for the
        readable path of the model file
    :raises ValueError: if the model file is absent from the saved model folder
    """
    folder_ctx = _get_keras_model_folder_context_from_saved_model(saved_model_id)
    model_file = keras_model_io_utils.KERAS_MODEL_FILENAME

    if not folder_ctx.isfile(model_file):
        raise ValueError("No model found for this saved model.")

    # The location is computed while the read context is still open
    with folder_ctx.get_file_path_to_read(model_file) as local_path:
        return keras_model_io_utils.safe_model_location(local_path)


def _get_keras_model_folder_context_from_trained_model(session_id=None, analysis_id=None, mltask_id=None):
    """Return the read-only model folder context of one session of a Keras visual ML task.

    :param session_id: session id of the form "s<N>"; None selects the session with the
        highest numeric suffix
    :param analysis_id: analysis id; None falls back to the environment variable named by
        ``doctor_constants.DKU_CURRENT_ANALYSIS_ID``
    :param mltask_id: ML task id; None falls back to the environment variable named by
        ``doctor_constants.DKU_CURRENT_MLTASK_ID``
    :return: the ``model_folder_context`` of the matching full model id
    :raises ValueError: if an id cannot be resolved, the ML task backend is not KERAS,
        the ML task has no trained session, or ``session_id`` is unknown
    """
    analysis_id = _get_variable_value(analysis_id, "analysis_id", doctor_constants.DKU_CURRENT_ANALYSIS_ID)
    mltask_id = _get_variable_value(mltask_id, "mltask_id", doctor_constants.DKU_CURRENT_MLTASK_ID)

    # Retrieve info on location of model
    project_key = remoterun.get_env_var("DKU_CURRENT_PROJECT_KEY")
    mltask = dataiku.api_client().get_project(project_key).get_ml_task(analysis_id, mltask_id)
    mltask_status = mltask.get_status()

    # Check good backend
    if mltask_status["headSessionTask"]["backendType"] != "KERAS":
        raise ValueError("The mltask you are accessing was not a Keras model")

    # We assume here that there is only one model per session, i.e. session_id are unique
    # in mltask_status["fullModelIds"], which is the case for KERAS backend
    full_model_ids = mltask_status["fullModelIds"]
    sessions = [p["fullModelId"]["sessionId"] for p in full_model_ids]

    if session_id is None:
        # Without this guard, an ML task with no trained session would fail with a bare IndexError
        if not sessions:
            raise ValueError("No trained session found for this mltask. Did it run without errors ?")
        # Session ids are "s<N>": compare the numeric part so that "s10" comes after "s9"
        last_session = max(int(sess_id_str[1:]) for sess_id_str in sessions)
        session_id = "s{}".format(last_session)

    try:
        session_index = sessions.index(session_id)
    except ValueError:
        raise ValueError("The 'session_id' you are providing cannot be found in the mltask. "
                         "Available session_ids are: {}".format(sessions))

    fmi = full_model_ids[session_index]["id"]
    fmi_folder_contexts = FmiReadonlyFolderContexts.build(fmi)
    return fmi_folder_contexts.model_folder_context


def _get_keras_model_folder_context_from_saved_model(saved_model_id):
    """Return the folder context of the active version of a saved model in the current project.

    :param saved_model_id: identifier of the saved model
    :return: the context built by ``build_saved_model_folder_context`` from the
        ``model_folder`` reported by the backend for the active version
    :raises ValueError: if the saved model has no active version
    """
    project_key = remoterun.get_env_var("DKU_CURRENT_PROJECT_KEY")
    saved_model = dataiku.api_client().get_project(project_key).get_saved_model(saved_model_id)

    # get_active_version() returns None when no version is active; indexing it directly
    # would surface as an opaque TypeError instead of an actionable error.
    active_version = saved_model.get_active_version()
    if active_version is None:
        raise ValueError("No active version found for saved model '{}'.".format(saved_model_id))

    sm_details = backend_json_call("savedmodels/get-model-details", data={
        "projectKey": project_key,
        "smId": saved_model_id,
        "versionId": active_version["id"]
    })

    model_folder = sm_details["model_folder"]
    return build_saved_model_folder_context(model_folder, project_key, saved_model_id)


def _get_variable_value(variable, variable_name, os_variable_name):
    """Return ``variable``, falling back to an environment variable when it is None.

    :param variable: value explicitly passed by the caller, or None
    :param variable_name: caller-facing argument name, used in the error message
    :param os_variable_name: name of the environment variable read via ``remoterun.get_env_var``
    :return: ``variable`` if not None, otherwise the environment variable's value
    :raises ValueError: if ``variable`` is None and the environment value is falsy
        (including the ``False`` default passed when it is absent)
    """
    if variable is not None:
        return variable

    # Look the environment variable up once instead of twice
    env_value = remoterun.get_env_var(os_variable_name, False)
    if not env_value:
        raise ValueError("You must provide an '{}' argument".format(variable_name))
    return env_value