import json
import logging
import sys

import pandas as pd
from dataiku.core.dataset import Dataset

logger = logging.getLogger(__name__)


class InvalidColumnNamesException(Exception):
    """Raised when required columns of the MLmodel signature are missing on the DSS side."""


def get_mlflow_dss_mappings():
    """Build the two-way lookup between MLflow data types and DSS column types.

    MLflow is imported inside the function so that importing this module does
    not require MLflow to be available.

    :return: a pair ``(mlflow_to_dss, dss_to_mlflow)``. The first dictionary is
        keyed by ``repr()`` of each ``mlflow.types.schema.DataType`` member and
        gives the DSS type name; the second is the inverse of the first.
    :rtype: Tuple[Dict[str, str], Dict[str, str]]
    """
    import mlflow

    data_type = mlflow.types.schema.DataType
    type_pairs = (
        (data_type.integer, "int"),
        (data_type.long, "bigint"),
        (data_type.float, "float"),
        (data_type.double, "double"),
        (data_type.string, "string"),
        (data_type.binary, "object"),
        (data_type.datetime, "date"),
        (data_type.boolean, "boolean"),
    )

    mlflow_to_dss = {repr(member): dss_name for member, dss_name in type_pairs}
    dss_to_mlflow = {dss_name: mlflow_name for mlflow_name, dss_name in mlflow_to_dss.items()}
    return mlflow_to_dss, dss_to_mlflow


def read_meta(model_folder_context):
    """Enrich ``mlflow_imported_model.json`` with metadata read from the MLflow model.

    The MLflow model stored in the folder is loaded and the following entries
    of ``mlflow_imported_model.json`` are filled in, each on a best-effort
    basis (a failure is logged as a warning and that entry is skipped):

    - ``timeCreated``: milliseconds since 1970-01-01 computed from the model's
      ``utc_time_created`` metadata
    - ``pyfuncLabels`` / ``flavorsLabels``: one ``{"key": "<flavor>:<setting>",
      "value": "<str(setting value)>"}`` entry appended per flavor setting
      (``python_function`` settings go to ``pyfuncLabels``, all others to
      ``flavorsLabels``)
    - ``features``: name and DSS type of every input of the model signature,
      appended to any features already present

    ``pythonVersion`` is always set to the version of the running interpreter,
    and the resulting JSON is written back to the folder.

    :param model_folder_context: folder access object; it must provide
        ``get_folder_path_to_read()``, ``read_json()`` and ``write_json()``
    """
    import mlflow

    with model_folder_context.get_folder_path_to_read() as model_folder_path:
        mlflow_model = mlflow.pyfunc.load_model(model_folder_path)

    mlflow_imported_model_filename = "mlflow_imported_model.json"
    mlflow_imported_model = model_folder_context.read_json(mlflow_imported_model_filename)

    mlflow_model_metadata = mlflow_model.metadata
    try:
        time_created = pd.to_datetime(mlflow_model_metadata.utc_time_created)
        mlflow_imported_model["timeCreated"] = (time_created - pd.Timestamp("1970-01-01")) // pd.Timedelta('1ms')
    except Exception:
        # Best effort: the creation time is optional, but leave a trace of the failure
        logger.warning("Could not read the creation time from the MLflow model metadata", exc_info=True)

    try:
        flavors = mlflow_model_metadata.flavors
        for top_level in flavors:
            if top_level == 'python_function':
                to_populate = mlflow_imported_model["pyfuncLabels"]
            else:
                to_populate = mlflow_imported_model["flavorsLabels"]
            for second_level in flavors[top_level]:
                to_populate.append({
                    "key": "{}:{}".format(top_level, second_level),
                    "value": str(flavors[top_level][second_level])
                })
    except Exception:
        # Best effort: labels are optional (this also covers a missing label list in the JSON)
        logger.warning("Could not read the flavors from the MLflow model metadata", exc_info=True)

    try:
        MLFLOW_DSS_TYPE, _ = get_mlflow_dss_mappings()
        signature = mlflow_model_metadata.signature
        if signature and signature.inputs:
            features = mlflow_imported_model.get("features", [])
            # The comprehension is fully evaluated before extend(), so a failure
            # on one input leaves the existing features untouched
            features.extend([
                {"name": feature["name"], "type": MLFLOW_DSS_TYPE[feature["type"]]}
                for feature in json.loads(signature.to_dict()["inputs"])
            ])
            mlflow_imported_model["features"] = features
    except Exception:
        # Best effort: inputs that cannot be mapped to DSS types are skipped as a whole
        logger.warning("Could not read the features from the MLflow model signature", exc_info=True)

    mlflow_imported_model["pythonVersion"] = "{}.{}.{}".format(sys.version_info.major, sys.version_info.minor, sys.version_info.micro)

    model_folder_context.write_json(mlflow_imported_model_filename, mlflow_imported_model)


def read_user_meta(model_folder_context):
    """Return the parsed content of ``user_meta.json`` read through the folder context.

    :param model_folder_context: folder access object providing ``read_json()``
    :return: whatever ``read_json("user_meta.json")`` returns
    """
    user_meta_filename = "user_meta.json"
    user_meta = model_folder_context.read_json(user_meta_filename)
    return user_meta


def set_formats(model_folder_context, input_dataset, sampling, target, input_format_name, output_format_name):
    """Resolve and persist the input/output formats of a proxy model.

    Nothing is done when ``mlflow_imported_model.json`` has no
    ``proxyModelVersionConfiguration`` entry, or when the loaded model
    implementation has no ``guess_formats`` attribute.

    Otherwise:
      - when ``input_dataset`` is given, a sample of it is loaded, the ``target``
        column is dropped if it is set and present, and the model's
        ``guess_formats`` is asked for the formats to use (the requested
        format names are passed along)
      - when ``input_dataset`` is None, both format names must be explicit:
        "GUESS" is rejected
      - the formats are applied on the model, then saved as ``inputFormat`` and
        ``outputFormat`` in ``mlflow_imported_model.json``

    :raises ValueError: if "GUESS" is requested without an input dataset, or if
        the model ends up without an input or output format
    """
    import mlflow

    descriptor_filename = "mlflow_imported_model.json"
    descriptor = model_folder_context.read_json(descriptor_filename)
    if descriptor.get("proxyModelVersionConfiguration") is None:
        logger.info("Not dealing with a proxy model, will not try to set formats.")
        return

    logger.info("Loading MLflow model in {}".format(model_folder_context))
    with model_folder_context.get_folder_path_to_read() as model_folder_path:
        proxy_model = mlflow.pyfunc.load_model(model_folder_path)._model_impl
    if not hasattr(proxy_model, "guess_formats"):
        logger.info("Not dealing with a format-dependent proxy model, will not try to set formats.")
        return

    if input_dataset is None:
        # Without data there is nothing to guess from: both names must be explicit
        if input_format_name == "GUESS":
            raise ValueError("You must provide an input_format when not providing an input_dataset")
        if output_format_name == "GUESS":
            raise ValueError("You must provide an output_format when not providing an input_dataset")
    else:
        logger.info("Trying to guess format using dataset {}".format(input_dataset))
        logger.info("Loading first rows of dataset {}".format(input_dataset))
        sample_df = Dataset(input_dataset).get_dataframe(sampling=sampling)
        has_target_column = target is not None and target in sample_df
        if has_target_column:
            logger.info("Dropping target {} on dataset {}".format(target, input_dataset))
            sample_df = sample_df.drop(target, axis=1)
        guessed_input, guessed_output = proxy_model.guess_formats(sample_df, input_format_name, output_format_name)
        input_format_name = guessed_input.NAME
        output_format_name = guessed_output.NAME

    proxy_model.set_input_output_formats(
        proxy_model.supported_input_formats,
        proxy_model.supported_output_formats,
        input_format_name,
        output_format_name,
    )

    if proxy_model.input_format is None or proxy_model.output_format is None:
        failure_template = (
            "Could not set input or output format, parameters were: model_folder: {}, "
            "input_dataset: {}, "
            "target: {}, "
            "input_format_name: {}, "
            "output_format_name: {}"
        )
        raise ValueError(failure_template.format(model_folder_context, input_dataset, target, input_format_name, output_format_name))

    descriptor["inputFormat"] = input_format_name
    descriptor["outputFormat"] = output_format_name
    model_folder_context.write_json(descriptor_filename, descriptor)


def set_signature(model_folder_context, dataset, sampling, target, features):
    """Attempt to set the input signature of the MLflow model stored in a model folder.

    - if ``features`` is a non-empty list, a default signature is created from it
    - if ``dataset`` is not None, a more complete signature is inferred from a
      sample of it and from the model's predictions on that sample; on success
      it replaces the features-based one, on failure the error is logged and
      the previous candidate (if any) is kept
    - if a signature was obtained and the MLmodel file already declares one,
      both are compared and nothing is written
    - if a signature was obtained and the MLmodel file has none, a copy of the
      MLmodel without signature is saved as ``MLmodel_backup`` and ``MLmodel``
      is then overwritten with the signature added

    :param model_folder_context: folder access object giving access to the model files
    :param dataset: name of the dataset used to infer the signature, or None
    :param sampling: sampling passed to ``Dataset.get_dataframe``
    :param target: name of the target column, dropped from the sample when it
        is set and present; may be None
    :param features: list of ``{"name": ..., "type": <DSS type>}`` dicts, or None
    :raises InvalidColumnNamesException: if the MLmodel file already has a
        signature and required columns it declares are missing from the new one
    """
    import mlflow
    from mlflow.models.signature import ModelSignature
    from mlflow.models.signature import infer_signature
    from mlflow.types.schema import ColSpec
    from mlflow.types.schema import Schema

    logger.info("Loading MLflow model in {}".format(model_folder_context))
    with model_folder_context.get_folder_path_to_read() as model_folder_path:
        model = mlflow.pyfunc.load_model(model_folder_path)
    metadata = model.metadata  # Content of the MLmodel file

    logger.info("Checking MLflow model signature")

    signature = None
    if features is not None and len(features) > 0:
        logger.info("Manually creating default signature using features {}".format(features))
        _, DSS_MLFLOW_TYPE = get_mlflow_dss_mappings()
        signature = ModelSignature(inputs=Schema([
            ColSpec(DSS_MLFLOW_TYPE[feature["type"]], feature["name"])
            for feature in features
        ]))

    if dataset is not None:
        logger.info("Trying to infer signature from dataset {}".format(dataset))
        try:
            logger.info("Loading first rows of dataset {}".format(dataset))
            df = Dataset(dataset).get_dataframe(sampling=sampling, infer_with_pandas=False)
            # Only drop the target when there is one to drop: an unset or absent
            # target must not make the whole inference fail (same guard as set_formats)
            if target is not None and target in df:
                logger.info("Dropping target {} on dataset {}".format(target, dataset))
                df = df.drop(target, axis=1)
            logger.info("Inferring signature from dataframe and model prediction")
            signature = infer_signature(df, model.predict(df))
        except Exception as e:
            msg = str(e)
            # Connection failures are not recoverable here: let the caller see them
            if msg.startswith("Failed to connect"):
                raise
            logger.info("Could not infer signature from dataset {} using MLflow model {}. Error {}".format(dataset, model_folder_context, msg))

    if signature is not None:
        logger.info("Signature specified for MLflow model: {}".format(signature))

        if metadata.signature is not None:
            # Log the signature found in the MLmodel file, not the specified one again
            logger.info("Signature from MLModel file: {}".format(metadata.signature))
            logger.info("Comparing MLModel signature with the one specified by DSS.")
            _compare_signatures(metadata.signature, signature)
        else:
            logger.info("MLModel file does not have a signature. Let's add the specified one.")

            # The backup is written before the signature is attached, so it holds the original content
            with model_folder_context.get_file_path_to_write("MLmodel_backup") as mlmodel_path_backup:
                logger.info("Saving a copy of MLmodel in: {}".format(mlmodel_path_backup))
                with open(mlmodel_path_backup, "w") as f_backup:
                    metadata.to_yaml(f_backup)

            metadata.signature = signature  # Adding the signature to metadata
            # The file is overwritten, so request a write path (as for the backup) instead of a read path.
            # NOTE(review): assumes get_file_path_to_write persists an existing file the same way it does a new one
            with model_folder_context.get_file_path_to_write("MLmodel") as mlmodel_path:
                logger.info("Writing signature in: {}".format(mlmodel_path))
                with open(mlmodel_path, "w") as f:
                    metadata.to_yaml(f)


def _compare_signatures(model_sig, dss_sig):
    """
    Check that the DSS signature is compatible with the one of the MLmodel file.

    Only the input schemas are compared; output schemas are ignored.
    Any exception raised by the schema comparison is propagated.

    :param model_sig: Signature coming from MLmodel
    :type model_sig: ModelSignature
    :param dss_sig: Signature coming from DSS
    :type dss_sig: ModelSignature
    """
    model_inputs_schema = model_sig.inputs
    dss_inputs_schema = dss_sig.inputs
    _compare_schemas(model_inputs_schema, dss_inputs_schema)


def _compare_schemas(model_schema, dss_schema):
    """
    Check that the DSS input schema is compatible with the MLmodel input schema.

    The comparison is skipped when either schema is None or is a tensor spec.
    When both schemas have column names, columns are compared by name (which
    may raise); otherwise only the lists of types are compared and a mismatch
    is logged as a warning.

    :param model_schema: Schema coming from MLmodel
    :type model_schema: Schema
    :param dss_schema: Schema coming from DSS
    :type dss_schema: Schema
    """
    # A missing schema leaves nothing to compare against
    if model_schema is None or dss_schema is None:
        return
    # Tensor specs are multidimensional and hard to compare: skip them
    if model_schema.is_tensor_spec() or dss_schema.is_tensor_spec():
        return

    both_have_names = model_schema.has_input_names() and dss_schema.has_input_names()
    if both_have_names:
        _lenient_compare_schemas(model_schema.inputs, dss_schema.inputs)
        return

    # No column names: only the types can be compared.
    # Schema.input_types_dict() is not used because it was only introduced in MLflow 2.4.0
    signature_types = model_schema.input_types()
    specified_types = dss_schema.input_types()
    if signature_types != specified_types:
        logger.warning(
            f"Types from specified features and features coming from MLModel signature are different. Types from signature are {signature_types} but specified types are {specified_types}."
        )


def _lenient_compare_schemas(model_inputs_specs, dss_inputs_specs):
    """Compare input columns by name, tolerating differences that are not blocking.

    Only a required column of the MLmodel signature that is absent on the DSS
    side raises. Absent optional columns are logged at info level; type
    differences and DSS columns unknown to the signature are logged as warnings.
    A signature column of type "any" is never reported as a type difference.

    :param model_inputs_specs: List of ColSpec coming from MLmodel
    :type model_inputs_specs: List[ColSpec]
    :param dss_inputs_specs: List of ColSpec coming from DSS
    :type dss_inputs_specs: List[ColSpec]
    :raises InvalidColumnNamesException: if names of columns are different
    """
    expected_columns = _get_columns_dict(model_inputs_specs)
    # Matched columns are popped from this dict, so what is left at the end is extra
    unmatched_dss_columns = _get_columns_dict(dss_inputs_specs)

    missing_required = []
    missing_optional = []
    type_mismatches = []
    for expected in expected_columns.values():
        provided = unmatched_dss_columns.pop(expected["name"], None)
        if provided is None:
            # Columns without a 'required' entry are treated as required
            bucket = missing_required if expected.get("required", True) else missing_optional
            bucket.append(expected)
            continue
        if expected["type"] != "any" and provided["type"] != expected["type"]:
            type_mismatches.append((expected, provided))

    extra_columns = list(unmatched_dss_columns.values())

    if missing_required:
        raise InvalidColumnNamesException(
            f"Some columns declared in MLModel signature are missing: {missing_required}."
        )
    if missing_optional:
        logger.info(f"Some optional columns declared in MLModel signature are missing: {missing_optional}.")
    if type_mismatches:
        signature_bad_types = [expected for expected, _ in type_mismatches]
        dss_bad_types = [provided for _, provided in type_mismatches]
        logger.warning(
            f"Types of some columns are different from MLModel signature. Types from signature are {signature_bad_types} but specified types are {dss_bad_types}."
        )
    if extra_columns:
        logger.warning(f"Some columns are not present in the MLModel signature: {extra_columns}.")


def _get_columns_dict(inputs_specs):
    """Index the dictionary form of each ColSpec by its column name.

    Working on ``to_dict()`` output instead of attributes keeps callers robust
    to old versions of MLflow whose ColSpec has no 'required' property.
    If several specs share a name, the last one wins.

    :param inputs_specs: list of MLflow ColSpec
    :type inputs_specs: List[ColSpec]
    :return: A dictionary indexed by column names
    :rtype: Dict[str, Dict[str, Any]]
    """
    return {spec.name: spec.to_dict() for spec in inputs_specs}
