import json
import logging
import traceback
import yaml

from dataiku.base.socket_block_link import JavaLink, parse_javalink_args
from dataiku.base.utils import watch_stdin, get_json_friendly_error
from dataiku.core import debugging
from dataiku.external_ml.mlflow.pyfunc_read_meta import get_mlflow_dss_mappings
from dataikuapi.dss.savedmodel import DSSSavedModel, DatabricksRepositoryContextManager

logger = logging.getLogger(__name__)

class RequestBase(object):
    """
    Read-only view over the decoded JSON payload of a command.

    Exposes the entries shared by every command handled in this module.

    :param data: Decoded command payload, indexed by string keys
    :type data: dict
    """

    def __init__(self, data):
        self._data = data

    def _required(self, key):
        # Plain indexing on purpose: a missing mandatory entry surfaces as a KeyError
        return self._data[key]

    @property
    def connection_info(self):
        """Value of the mandatory "connectionInfo" entry of the payload."""
        return self._required("connectionInfo")

    @property
    def use_unity_catalog(self):
        """Value of the mandatory "useUnityCatalog" entry of the payload."""
        return self._required("useUnityCatalog")


class RequestDownload(RequestBase):
    """
    Payload of a "RequestDownload" command.

    The constructor is inherited unchanged from :class:`RequestBase` and takes the decoded payload.
    """

    @property
    def model_name(self):
        """Value of the mandatory "modelName" entry of the payload."""
        return self._data["modelName"]

    @property
    def model_version(self):
        """Value of the mandatory "modelVersion" entry of the payload."""
        return self._data["modelVersion"]

    @property
    def target_directory(self):
        """Value of the mandatory "targetDirectory" entry of the payload."""
        return self._data["targetDirectory"]


class ListRegisteredModels(RequestBase):
    """
    Payload of a "ListRegisteredModels" command.

    It adds nothing to :class:`RequestBase`: the constructor and the accessors are all inherited.
    """


class ListRegisteredModelVersions(RequestBase):
    """
    Payload of a "ListRegisteredModelVersions" command.

    The constructor is inherited unchanged from :class:`RequestBase` and takes the decoded payload.
    """

    @property
    def model_name(self):
        """Value of the mandatory "modelName" entry of the payload."""
        return self._data["modelName"]


class RequestModelRegistration(RequestBase):
    """
    Payload of a "RequestModelRegistration" command.

    The constructor is inherited unchanged from :class:`RequestBase` and takes the decoded payload.
    "columns" and "targetName" are optional entries; every other entry is mandatory and raises
    KeyError when it is missing.
    """

    @property
    def model_name(self):
        """Value of the mandatory "modelName" entry of the payload."""
        return self._data["modelName"]

    @property
    def model_directory(self):
        """Value of the mandatory "modelDirectory" entry of the payload."""
        return self._data["modelDirectory"]

    @property
    def experiment_name(self):
        """Value of the mandatory "experimentName" entry of the payload."""
        return self._data["experimentName"]

    @property
    def columns(self):
        """Value of the optional "columns" entry, or None when the payload does not have it."""
        return self._data.get("columns")

    @property
    def target_name(self):
        """Value of the optional "targetName" entry, or None when the payload does not have it."""
        return self._data.get("targetName")


class MLflowUtilsProtocol(object):
    def __init__(self, link):
        """
        :param link: Link used to exchange JSON messages with the peer; this class calls ``read_json()``
            and ``send_json()`` on it
        """
        self.link = link

    def _handle_request_download(self, params):
        """
        Download a model version from the Databricks registry into a local directory, then acknowledge.

        The download runs inside a ``DatabricksRepositoryContextManager`` built from the connection info and
        the Unity Catalog flag of the request. The acknowledgement is sent once that context has been exited.

        :param params: Parameters of the command
        :type params: RequestDownload
        """
        repository_context = DatabricksRepositoryContextManager(params.connection_info, params.use_unity_catalog)
        with repository_context:
            DSSSavedModel._download_from_databricks_registry(
                model_name=params.model_name,
                model_version=params.model_version,
                target_directory=params.target_directory,
            )

        acknowledgement = {"type": "RequestDownloadResponse"}
        self.link.send_json(acknowledgement)

    def _handle_list_registered_models(self, params):
        """
        List the registered models reachable through the requested connection and send them over the link.

        ``MlflowClient.search_registered_models()`` is paginated: one call only returns the first page of
        results. Following pages are fetched for as long as the returned page exposes a continuation token,
        so the response is not silently truncated on large registries.

        :param params: Parameters of the command
        :type params: ListRegisteredModels
        """
        with DatabricksRepositoryContextManager(params.connection_info, params.use_unity_catalog):
            import mlflow
            client = mlflow.client.MlflowClient()
            ret = []
            page = client.search_registered_models()
            while True:
                ret.extend({
                    "creation_timestamp": reg_model.creation_timestamp,
                    "last_updated_timestamp": reg_model.last_updated_timestamp,
                    "description": reg_model.description,
                    "name": reg_model.name
                } for reg_model in page)
                # A plain list (no "token" attribute) or an empty token both mean there is nothing left to fetch
                next_page_token = getattr(page, "token", None)
                if not next_page_token:
                    break
                page = client.search_registered_models(page_token=next_page_token)
            self.link.send_json({"type": "ListRegisteredModelsResponse", "result": ret})

    def _handle_list_registered_model_versions(self, params):
        """
        Search the versions of one registered model and send their description over the link.

        :param params: Parameters of the command
        :type params: ListRegisteredModelVersions
        """
        with DatabricksRepositoryContextManager(params.connection_info, params.use_unity_catalog):
            import mlflow
            client = mlflow.client.MlflowClient()
            # The filter is a quoted expression: single quotes of the model name are escaped by doubling them
            escaped_name = params.model_name.replace("'", "''")
            name_filter = "name='{}'".format(escaped_name)
            versions = [
                {
                    "creation_timestamp": version.creation_timestamp,
                    "last_updated_timestamp": version.last_updated_timestamp,
                    "description": version.description,
                    "name": version.name,
                    "run_id": version.run_id,
                    "source": version.source,
                    "status": version.status,
                    "user_id": version.user_id,
                    "version": version.version,
                }
                for version in client.search_model_versions(name_filter)
            ]
            self.link.send_json({"type": "ListRegisteredModelVersionsResponse", "result": versions})

    def _handle_register_model(self, params):
        """
        Log a local MLflow model directory as the artifact of a new run, register it, and report the result.

        Steps, all performed inside the ``DatabricksRepositoryContextManager`` of the request:

        1. load the MLmodel file and read the python version declared in conda.yaml;
        2. when the model has no signature, build a default one from the request columns and target (if given);
        3. reuse the experiment named in the request, or create it, and start a run in it;
        4. rewrite the MLmodel file with the artifact path and run id, then log the whole directory;
        5. register the logged artifact under the requested model name and send the outcome over the link.

        :param params: Parameters of the command
        :type params: RequestModelRegistration
        """
        with DatabricksRepositoryContextManager(params.connection_info, params.use_unity_catalog):
            import os
            import mlflow
            from mlflow.models.model import MLMODEL_FILE_NAME
            from mlflow.models.signature import ModelSignature
            from mlflow.types.schema import ColSpec
            from mlflow.types.schema import Schema

            # Step 1: MLmodel descriptor and python version of the model
            mlmodel_path = os.path.join(params.model_directory, MLMODEL_FILE_NAME)
            mlmodel = self._safe_load_mlmodel(mlmodel_path)
            model_python_version = self._get_python_version_of_model(
                os.path.join(params.model_directory, "conda.yaml"))

            # Step 2: the signature gets special care because it is required when exporting in the Unity Catalog
            if mlmodel.signature:
                logger.info("Model already has a signature - using it")
            elif not (params.columns and params.target_name):
                logger.info("Model has no schema or no target defined. Can not generate a signature")
            else:
                logger.info("Manually creating default signature using columns {} and target {}".format(params.columns, params.target_name))
                _, dss_to_mlflow_type = get_mlflow_dss_mappings()

                def as_col_specs(features):
                    return [ColSpec(dss_to_mlflow_type[feature["type"]], feature["name"]) for feature in features]

                # Every column except the target is an input; the target column(s) make up the output
                input_features = [feature for feature in params.columns if feature["name"] != params.target_name]
                target_features = [feature for feature in params.columns if feature["name"] == params.target_name]
                default_signature = ModelSignature(inputs=Schema(as_col_specs(input_features)),
                                                   outputs=Schema(as_col_specs(target_features)))
                logger.info("Computed signature: {}".format(default_signature))
                mlmodel.signature = default_signature

            # Step 3: the model must first be logged as an artefact of a run of an experiment
            artifact_path = "imported_model"
            experiment = mlflow.get_experiment_by_name(params.experiment_name)
            if experiment:
                experiment_id = experiment.experiment_id
            else:
                experiment_id = mlflow.create_experiment(params.experiment_name)

            with mlflow.start_run(experiment_id=experiment_id) as run:
                run_id = run.info.run_id

                # Step 4: add the run-related fields to the MLmodel file before uploading the directory
                mlmodel.artifact_path = artifact_path
                mlmodel.run_id = run_id
                logger.info("MLmodel file will be updated at path: {}".format(mlmodel_path))
                self._safe_save_mlmodel(mlmodel, mlmodel_path, model_python_version)

                mlflow.log_artifacts(local_dir=params.model_directory, artifact_path=artifact_path)
                self._record_logged_model(mlmodel)

            # Step 5: register the logged model in the target registry and report what was created
            registered_version = mlflow.register_model("runs:/" + run_id + "/" + artifact_path, params.model_name)
            self.link.send_json({
                "type": "RequestModelRegistrationResponse",
                "name": registered_version.name,
                "description": registered_version.description,
                "status": registered_version.status,
                "statusMessage": registered_version.status_message,
                "version": registered_version.version,
                "runId": run_id,
                "experimentId": experiment_id,
            })

    def _safe_load_mlmodel(self, mlmodel_path):
        """
        Load the MLmodel file, retrying with a simplified signature when the first attempt fails.

        Loading the signature may fail on old MLflow versions (before 2.10.0) when inputs or outputs contain
        "required" or "optional" fields. When building the model raises a TypeError, the unsupported signature
        fields are removed from the parsed content and the model is built again.

        :param mlmodel_path: Path to the MLModel file
        :type mlmodel_path: str
        :return: The loaded model from MLModel file
        :rtype: Model
        """
        from mlflow.models import Model

        with open(mlmodel_path) as mlmodel_file:
            parsed_mlmodel = yaml.safe_load(mlmodel_file.read())

        try:
            return Model.from_dict(parsed_mlmodel)
        except TypeError:
            logger.warning("Model signature may be not compatible with your version of MLflow. We will try to convert the signature to the expected format.")
            logger.warning("Updating MLFlow to version 2.10.0+ should remove this warning.")

        # Second and last attempt, with the signature reduced to the fields every MLflow version accepts
        self._delete_unsupported_fields_from_signature(parsed_mlmodel)
        return Model.from_dict(parsed_mlmodel)
            
    def _safe_save_mlmodel(self, mlmodel, mlmodel_path, python_version):
        """
        Write the model back to its MLmodel file, simplifying the signature for old python versions.

        An old python version may force the use of a virtual env with mlflow 1 instead of mlflow 2, and
        mlflow 1 does not support all the signature fields. When the python version of the model is too old,
        the unsupported fields are removed from the serialized signature before the file is written.

        :param mlmodel: MLModel object
        :type mlmodel: Model
        :param mlmodel_path: Path to the MLModel file
        :type mlmodel_path: str
        :param python_version: Python version of the model
        :type python_version: str
        """
        serialized_model = mlmodel.to_dict()

        needs_legacy_signature = self._should_update_signature_on_disk(python_version)
        if needs_legacy_signature:
            logger.warning("Python version of the MLflow model is 3.7 or older. We will convert the signature of the MLModel to insure it's readable with a old version of the mlflow python library.")
            self._delete_unsupported_fields_from_signature(serialized_model)

        with open(mlmodel_path, "w") as mlmodel_file:
            yaml.safe_dump(serialized_model, stream=mlmodel_file, default_flow_style=False)

    def _get_python_version_of_model(self, conda_path):
        """
        Try to find the python version of the python model in the conda.yaml file.
        If conda.yaml is not found or if there is no python definition, it's OK, we will assume everything will work out of the box.

        Only plain string dependencies starting with "python=" are considered. Other entries, such as the
        nested ``{"pip": [...]}`` mapping, are skipped instead of aborting the lookup.

        :param conda_path: Path for the conda.yaml file
        :type conda_path: str
        :return: Python version of the model or None if it was impossible to get it
        :rtype: (str | None)
        """
        python_prefix = "python="
        try:
            with open(conda_path) as f:
                conda_dict = yaml.safe_load(f.read())
            # An empty file parses to None, and "dependencies" may be present but empty
            dependencies = (conda_dict or {}).get("dependencies") or []
            for dependency in dependencies:
                if isinstance(dependency, str) and dependency.startswith(python_prefix):
                    # lstrip("=") also copes with the "python==3.7" form of the spec
                    return dependency[len(python_prefix):].lstrip("=").strip()
            return None
        except Exception:
            # Best effort by design: any failure means "unknown version", never an aborted registration
            logger.warning("Impossible to read conda.yaml file. We will assume python 3.8+ is used for this conda env.", exc_info=True)
            return None
    
    def _should_update_signature_on_disk(self, python_version):
        """
        mlflow 2 doesn't support python version before python 3.8.
        So if the python version of the model is python 3.7 or older, mlflow 1 will be used when creating a conda env during Databricks endpoint initialization.
        But mlflow 1 doesn't support 'required' and 'optional' fields in schemas so we have to delete them from YAML file.

        A version that cannot be interpreted (None, fewer than two components, non-numeric components)
        never triggers an update: when we don't know, we do nothing.

        :param python_version: Python version of the model
        :type python_version: str
        :return: If you should update the signature of the MLModel or not
        :rtype: bool
        """
        if python_version is None:
            # We don't know. Let's do nothing.
            return False

        # A "python==3.7" spec may leave a leading "=" in front of the version
        version_elems = python_version.strip().lstrip("=").split(".")
        if len(version_elems) < 2:
            # Strange version. We don't know. Let's do nothing.
            return False

        try:
            major = int(version_elems[0])
            if major < 3:
                # python 2 is definitely old
                return True
            if major > 3:
                # Newer than any python 3: not old, whatever the minor version is
                return False
            # python 3.7 and older are old; python 3.8+ is OK, no modification needed
            return int(version_elems[1]) < 8
        except ValueError:
            # Non-numeric component (e.g. "3.x"). We don't know. Let's do nothing.
            return False

    def _delete_unsupported_fields_from_signature(self, model_dict):
        """
        mlflow.types.ColSpec#__init__() is in the form __init__(self, type, name) for MLFlow versions prior to 2.4.0.
        Let's remove other kwargs

        The "inputs" and "outputs" entries of the signature are JSON-encoded lists; both are rewritten in place.
        A missing or null signature, and a missing or null "inputs"/"outputs" entry, are left untouched.

        :param model_dict: Dictionary of the MLModel file
        :type model_dict: dict
        """
        signature = model_dict.get("signature")
        if not signature:
            return

        valid_fields = ["type", "name"]
        for schema_key in ("inputs", "outputs"):
            serialized_schema = signature.get(schema_key)
            if serialized_schema is None:
                # e.g. "outputs: null" for a signature that only describes the inputs
                continue
            col_specs = json.loads(serialized_schema)
            clean_col_specs = [self._delete_colspec_fields_not_in(valid_fields, obj) for obj in col_specs]
            signature[schema_key] = json.dumps(clean_col_specs)

    def _delete_colspec_fields_not_in(self, valid_fields, obj):
        """
        Return a copy of a column spec restricted to the given fields; tensor specs are returned as is.

        :param valid_fields: Names of the fields to keep
        :type valid_fields: list
        :param obj: One decoded entry of a signature schema; it must have a "type" entry
        :type obj: dict
        :return: The same object when its type is "tensor", otherwise a new filtered dictionary
        :rtype: dict
        """
        is_tensor_spec = obj["type"] == "tensor"
        if not is_tensor_spec:
            # Only ColSpec entries need cleaning; the original key order is kept
            kept_fields = {}
            for field_name, field_value in obj.items():
                if field_name in valid_fields:
                    kept_fields[field_name] = field_value
            return kept_fields

        # TensorSpec entries are handed back untouched
        return obj

    def _record_logged_model(self, mlmodel):
        """
        Record the logged model through MLflow's fluent tracking module, without ever failing.

        Courtesy of mlflow/models/model.py#log(...).
        The call goes to an internal MLflow method, so every exception is caught and logged, including the
        AttributeError raised when that method does not exist.

        :param mlmodel: The model coming from the MLmodel file
        :type mlmodel: mlflow.models.Model
        """

        import mlflow
        from mlflow.exceptions import MlflowException

        try:
            import mlflow.tracking.fluent
            record_logged_model = mlflow.tracking.fluent._record_logged_model
            record_logged_model(mlmodel)
        except AttributeError:
            # The internal _record_logged_model() method is not available in this MLflow version
            logger.warning("Recording additional metadata for the model failed. You may use a too old version of MLflow (< 1.7.0) or an internal method has been removed.", exc_info=True)
        except MlflowException:
            # MLflow errors are swallowed to stay backwards compatible with older tracking servers
            logger.exception("Logging model metadata to the tracking server has failed, possibly due older server version.")
        except Exception:
            # Anything else: this step is not crucial, so log the failure and let the registration continue
            logger.exception("Recording additional metadata for the model failed.")

    def start(self):
        """
        Read commands from the link and dispatch them until the connection is closed.

        Each command is a JSON object whose "type" entry selects the handler. An EOFError ends the loop
        normally. On any other error, a "<type>Response" message carrying the JSON-friendly error is sent
        over the link, then the exception is re-raised. The type used in that message is the one of the last
        command read, or "undefined" when it is not available.

        :raises Exception: When the command type is unknown, or whatever a handler raised
        """
        command = {"type": "undefined"}
        try:
            while True:
                command = self.link.read_json()
                command_type = command["type"]
                if command_type == "RequestDownload":
                    request_download = RequestDownload(command)
                    self._handle_request_download(request_download)
                elif command_type == "ListRegisteredModels":
                    list_registered_models = ListRegisteredModels(command)
                    self._handle_list_registered_models(list_registered_models)
                elif command_type == "ListRegisteredModelVersions":
                    list_registered_model_versions = ListRegisteredModelVersions(command)
                    self._handle_list_registered_model_versions(list_registered_model_versions)
                elif command_type == "RequestModelRegistration":
                    request_model_registration = RequestModelRegistration(command)
                    self._handle_register_model(request_model_registration)
                else:
                    # Report the type only: the command is a dict (not concatenable to a string) and its
                    # payload carries a "connectionInfo" entry that has no place in an error message
                    raise Exception("Unexpected command {}".format(command_type))
        except EOFError:
            logger.info("Connection with client closed")
        except BaseException:
            # As broad as a bare "except" on purpose: the peer is always notified, then the error is re-raised
            traceback.print_exc()
            error = get_json_friendly_error()
            # The failure may be a malformed message: no "type" entry, or not even a dict
            failed_type = command.get("type", "undefined") if isinstance(command, dict) else "undefined"
            self.link.send_json({'type': "{}Response".format(failed_type), 'error': error})
            logger.error("Error during MLflow operation {}".format(error))
            raise


def serve(port, secret, server_cert=None):
    """
    Connect a JavaLink and serve MLflow commands over it until the protocol loop ends.

    The link is closed when the loop returns or raises. A failure while connecting happens before that
    guarantee applies.

    :param port: Port passed to the JavaLink
    :param secret: Secret passed to the JavaLink
    :param server_cert: Optional server certificate passed to the JavaLink
    """
    java_link = JavaLink(port, secret, server_cert=server_cert)
    java_link.connect()
    handler = MLflowUtilsProtocol(java_link)
    try:
        handler.start()
    finally:
        java_link.close()


if __name__ == "__main__":
    # Script entry point: configure logging, then serve MLflow commands over the link described by the
    # launch arguments.
    logging.basicConfig(level=logging.INFO,
                        format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s')
    # NOTE(review): presumably installs a debugging hook for this process - confirm in dataiku.core.debugging
    debugging.install_handler()

    # NOTE(review): presumably stops this process when the parent closes stdin - confirm in dataiku.base.utils
    watch_stdin()
    # Port, secret and optional server certificate of the link to connect to
    port, secret, server_cert = parse_javalink_args()
    serve(port, secret, server_cert=server_cert)
