import json
import logging

from dataiku.external_ml.proxy_model.common import BatchConfiguration
from dataiku.external_ml.proxy_model.common import ChunkedAndFormatGuessingProxyModel
from dataiku.external_ml.proxy_model.common import ProxyModelEndpointClient
from dataiku.external_ml.proxy_model.vertex_ai.inputformat import VertexAIDefaultWriter
from dataiku.external_ml.proxy_model.vertex_ai.outputformat import VertexAIDefaultReader

logger = logging.getLogger(__name__)

# Byte limit handed to BatchConfiguration when the proxy model is constructed.
VERTEX_AI_BYTES_LIMIT = 1400000  # Vertex AI limits at 1.5 MB per request, so let's be a tad under

# Formats will be tried in the defined order. Order is important, so don't change it unless you really know what you're doing.
# Candidate output formats (readers), passed to the proxy model base constructor.
OutputFormats = [
    VertexAIDefaultReader
]
# Candidate input formats (writers), passed to the proxy model base constructor.
InputFormats = [
    VertexAIDefaultWriter
]


class VertexAIEndpointClient(ProxyModelEndpointClient):
    """Client that POSTs prediction payloads to a Vertex AI ``:predict`` REST URL.

    Authentication uses, in order of preference, a pre-obtained bearer token or
    a google-auth credentials object that is refreshed on every call. If neither
    is available the request is sent without an ``Authorization`` header.
    """

    def __init__(self, location, resource_name, proxies, credentials, token):
        """
        :param location: value used as the host prefix of ``<location>-aiplatform.googleapis.com``
        :param resource_name: resource path inserted after ``/v1/`` in the URL
        :param proxies: proxies mapping forwarded to ``requests`` (or None)
        :param credentials: object exposing ``refresh(request)`` and ``token`` (or None)
        :param token: bearer token to use as-is; takes precedence over ``credentials`` (or None)
        """
        self.proxies = proxies
        self.credentials = credentials
        self.token = token
        self.uri = "https://{}-aiplatform.googleapis.com/v1/{}:predict".format(location, resource_name)
        super(VertexAIEndpointClient, self).__init__()

    def call_endpoint(self, parameters_dict, content_type):
        """Send ``parameters_dict`` as a JSON body to the endpoint and return the decoded JSON response.

        :param parameters_dict: the data to send to the endpoint for prediction
        :param content_type: unused for Vertex
        :return: the parsed JSON body of the response
        :raises Exception: if refreshing the credentials fails
        :raises requests.HTTPError: via ``raise_for_status`` on an HTTP error status
        """
        # Imported lazily, as in the rest of this module.
        import google.auth.transport.requests
        import requests

        # Default to "no token" so that the header logic below is well defined even when
        # neither a token nor credentials are available (previously an UnboundLocalError).
        token = None
        if self.token:
            token = self.token
        elif self.credentials:
            session = requests.Session()
            if self.proxies:
                session.proxies.update(self.proxies)
            # NOTE(review): TLS certificate verification is disabled for the credential
            # refresh request only. Kept as-is to preserve behavior; confirm this is intended.
            session.verify = False
            try:
                self.credentials.refresh(google.auth.transport.requests.Request(session))
            except Exception as e:
                raise Exception("Exception when using credentials to get an authentication token: {}".format(e))
            token = self.credentials.token

        if token:
            headers = {"Authorization": "Bearer {}".format(token)}
        else:
            headers = None

        response = requests.post(
            self.uri,
            proxies=self.proxies,
            headers=headers,
            json=parameters_dict  # The data to send to the endpoint for prediction
        )
        if response.status_code != requests.codes.ok:
            # Raises for 4xx/5xx; other non-200 statuses fall through to the JSON decoding.
            response.raise_for_status()
        return response.json()


class VertexAIProxyModel(ChunkedAndFormatGuessingProxyModel):
    """Proxy model bound to one Vertex AI endpoint, identified by project, location and endpoint id."""

    def __init__(self, project_id, endpoint_id, meta, location=None, connection=None, **kwargs):
        """
        :param project_id: project identifier used to build the endpoint resource name
        :param endpoint_id: endpoint identifier used to build the endpoint resource name
        :param meta: mapping read with ``.get`` for the keys ``predictionType``,
                     ``intToLabelMap``, ``inputFormat`` and ``outputFormat``
        :param location: location used both in the host name and in the resource name
        :param connection: name of the connection to authenticate with, or None to
                           authenticate from the environment
        :param kwargs: accepted for call compatibility; not used here
        """
        self.project_id = project_id
        self.endpoint_id = endpoint_id
        self.location = location
        self.connection = connection
        super(VertexAIProxyModel, self).__init__(
            meta.get("predictionType"),
            meta.get("intToLabelMap"),
            InputFormats,
            OutputFormats,
            meta.get("inputFormat"),
            meta.get("outputFormat"),
            BatchConfiguration(VERTEX_AI_BYTES_LIMIT)
        )

    def get_client(self):
        """Resolve proxy and authentication settings and build the endpoint client.

        Authentication is resolved as follows:

        * no connection configured: credentials from the environment;
        * connection configured but ``runs_on_real_api_node()`` is true: the connection
          is not read, credentials come from the environment;
        * otherwise the connection's ``authType`` decides: ``ENVIRONMENT``, ``KEYPAIR``
          (service account key in ``appSecretContent``) or ``OAUTH`` (per-user access token).

        :return: a ``VertexAIEndpointClient`` for this model's endpoint
        :raises Exception: if key pair auth is configured without a key, or the auth type is unknown
        """
        import google.auth
        from google.oauth2 import service_account

        proxy = self.get_proxy()
        if proxy:
            # The same HTTP proxy URL is used for both http and https traffic.
            proxies = {
                "http": "http://" + proxy,
                "https": "http://" + proxy,
            }
            logger.debug("Using proxies: {}".format(proxies))
        else:
            logger.debug("No applicative proxy configuration. Proxies may still be defined "
                         "with HTTP_PROXY and HTTPS_PROXY")
            proxies = None

        credentials = None
        token = None

        # Requested on every credentials-based path so that all of them behave the same;
        # google.auth.default only applies scopes to credentials that need them.
        scopes = ["https://www.googleapis.com/auth/cloud-platform"]

        if self.connection is not None:
            if VertexAIProxyModel.runs_on_real_api_node():
                logger.info("NOT getting connection params from connection {}. Authentication will be performed from environment".format(self.connection))
                credentials, _ = google.auth.default(scopes=scopes)
            else:
                logger.info("Using connection {} to authenticate".format(self.connection))
                dss_connection = VertexAIProxyModel.get_connection_info(self.connection, None, "VertexAIModelDeployment", "Google Vertex AI connection")
                params = dss_connection.get_resolved_params()
                auth_type = params.get("authType")
                if auth_type == "ENVIRONMENT":
                    logger.debug("Retrieving credentials from environment")
                    credentials, _ = google.auth.default(scopes=scopes)
                elif auth_type == "KEYPAIR":
                    logger.debug("Using keypair configured in connection")
                    key = params.get("appSecretContent")
                    if key is None:
                        raise Exception("Key pair auth configured, but no private key available")
                    # The key is stored as a JSON document describing the service account.
                    info = json.loads(key)
                    credentials = service_account.Credentials.from_service_account_info(info, scopes=scopes)
                elif auth_type == "OAUTH":
                    logger.debug("Using oauth per user authentication")
                    # The access token is used directly as the bearer token; no credentials object.
                    token = dss_connection.get_oauth2_credential()["accessToken"]
                else:
                    raise Exception("Unhandled auth type: {}".format(auth_type))
        else:
            logger.debug("No connection configured. Retrieving credentials from environment.")
            credentials, _ = google.auth.default(scopes=scopes)

        logger.info("Initialized VertexAI client with endpoint name '{endpoint_id}' "
                    "from project '{project_id}' in location '{location}'."
                    "".format(endpoint_id=self.endpoint_id,
                              project_id=self.project_id,
                              location=self.location))
        endpoint_name = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=self.project_id,
                                                                                              location=self.location,
                                                                                              endpoint=self.endpoint_id)
        return VertexAIEndpointClient(self.location, endpoint_name, proxies, credentials, token)
