import io
import json
import os


from .dss import DSSProxyModel
from .kserve import KServeProxyModel
from .vertex_ai import VertexAIProxyModel
from .sagemaker import SagemakerProxyModel
from .azure_ml import AzureMLProxyModel
from .databricks import DatabricksProxyModel

from dataiku.external_ml.utils import load_external_model_meta
from dataiku.base.folder_context import build_noop_folder_context


import logging

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Protocol names accepted by _load_proxy_model; any other value makes it raise
# InvalidProxyModelProtocolException. Keep this list in sync with the
# per-protocol construction code in _load_proxy_model.
VALID_PROXY_MODEL_PROTOCOLS = ["sagemaker", "kserve", "dss-api-node", "vertex-ai", "azure-ml", "databricks"]


class InvalidProxyModelProtocolException(Exception):
    """Raised when a proxy model configuration names an unsupported protocol.

    A protocol is supported when it appears in ``VALID_PROXY_MODEL_PROTOCOLS``.
    """


def _load_proxy_model(configuration, protocol, meta):
    """Build the proxy model client that serves ``protocol``.

    :param configuration: proxy model version configuration dictionary.
        Protocol-specific keys are read both at its top level (for example
        ``endpoint_name``, ``endpoint_id``, ``model_id``, ``service_id`` or
        ``endpointName``) and inside its ``"proxyModelConfiguration"`` sub-dict.
    :param protocol: serving protocol name. It must appear in
        ``VALID_PROXY_MODEL_PROTOCOLS``.
    :param meta: external model metadata. It is forwarded to the SageMaker,
        Vertex AI, Azure ML and Databricks clients.
    :return: the instantiated proxy model.
    :raises InvalidProxyModelProtocolException: if ``protocol`` is not listed in
        ``VALID_PROXY_MODEL_PROTOCOLS``.
    :raises NotImplementedError: if ``protocol`` is listed as valid but has no
        builder below.
    """
    if protocol not in VALID_PROXY_MODEL_PROTOCOLS:
        raise InvalidProxyModelProtocolException("Protocol {} not implemented".format(protocol))

    # TODO: possibly deserialize the configurations to Python classes rather than dictionaries
    logger.info("Using protocol {}".format(protocol))

    def build_sagemaker():
        endpoint_name = configuration["endpoint_name"]
        settings = configuration["proxyModelConfiguration"]
        return SagemakerProxyModel(
            endpoint_name=endpoint_name,
            meta=meta,
            region=settings.get("region"),
            connection=settings.get("connection"),
        )

    def build_vertex_ai():
        settings = configuration["proxyModelConfiguration"]
        project_id = settings["project_id"]
        endpoint_id = configuration["endpoint_id"]
        # The stored "region" setting is passed to the client as its location.
        return VertexAIProxyModel(
            project_id=project_id,
            endpoint_id=endpoint_id,
            meta=meta,
            location=settings.get("region"),
            connection=settings.get("connection"),
        )

    def build_azure_ml():
        settings = configuration["proxyModelConfiguration"]
        subscription_id = settings["subscription_id"]
        resource_group = settings["resource_group"]
        workspace = settings["workspace"]
        endpoint_name = configuration["endpoint_name"]
        return AzureMLProxyModel(
            subscription_id=subscription_id,
            resource_group=resource_group,
            workspace=workspace,
            endpoint_name=endpoint_name,
            meta=meta,
            connection=settings.get("connection"),
        )

    def build_kserve():
        settings = configuration["proxyModelConfiguration"]
        root_url = settings["root_url"]
        model_id = configuration["model_id"]
        return KServeProxyModel(
            root_url=root_url,
            model_id=model_id,
            headers=settings["headers"],
        )

    def build_dss_api_node():
        settings = configuration["proxyModelConfiguration"]
        root_url = settings["root_url"]
        service_id = configuration["service_id"]
        endpoint_id = configuration["endpoint_id"]
        return DSSProxyModel(
            root_url=root_url,
            service_id=service_id,
            endpoint_id=endpoint_id,
        )

    def build_databricks():
        # NOTE(review): unlike the other protocols this reads the camelCase key
        # "endpointName"; presumably it mirrors how this protocol's config is
        # serialized. Confirm against the writer of the config.
        endpoint_name = configuration["endpointName"]
        settings = configuration["proxyModelConfiguration"]
        return DatabricksProxyModel(
            endpoint_name=endpoint_name,
            meta=meta,
            connection=settings.get("connection"),
        )

    builders = {
        "sagemaker": build_sagemaker,
        "vertex-ai": build_vertex_ai,
        "azure-ml": build_azure_ml,
        "kserve": build_kserve,
        "dss-api-node": build_dss_api_node,
        "databricks": build_databricks,
    }
    builder = builders.get(protocol)
    if builder is None:
        # Only reachable if VALID_PROXY_MODEL_PROTOCOLS lists a protocol that
        # has no builder in this table.
        raise NotImplementedError
    return builder()


def _load_pyfunc(model_uri):
    """Load the proxy model stored in the directory ``model_uri``.

    The function does three things:

    1. Reads the ``"proxyModelVersionConfiguration"`` entry of
       ``mlflow_imported_model.json`` in that directory.
    2. Loads the external model metadata from the same directory.
    3. Builds the proxy model for the configured protocol.

    :param model_uri: local filesystem path of the directory holding the
        imported model files.
    :return: the proxy model built by ``_load_proxy_model``.
    :raises KeyError: if the JSON has no ``"proxyModelVersionConfiguration"``
        entry, or that entry has no ``"protocol"`` key.
    """
    logger.info("Loading Proxy model configuration in MLmodel file")

    logger.info("Loading configuration of proxy model and meta")
    config_path = os.path.join(model_uri, "mlflow_imported_model.json")
    # JSON text is UTF-8 by specification. Decode it explicitly rather than with
    # the platform's locale-dependent default text encoding, which can garble
    # or reject non-ASCII values. io.open accepts `encoding` on Python 2 and 3.
    with io.open(config_path, encoding="utf-8") as config_file:
        configuration = json.load(config_file)["proxyModelVersionConfiguration"]

    meta = load_external_model_meta(build_noop_folder_context(model_uri))

    protocol = configuration["protocol"]
    return _load_proxy_model(configuration, protocol, meta)
