import json
from datetime import datetime
from io import BytesIO
from typing import Any, Dict, Generator, List, Optional, Union

import dataiku
from common.backend.constants import KEYS_TO_REMOVE_FROM_LOGS
from common.backend.models.base import LlmHistory, MediaSummary
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.file_utils import get_file_mime_type
from common.backend.utils.json_utils import mask_keys
from common.backend.utils.picture_utils import get_image_bytes
from common.llm_assist.logging import logger
from dataikuapi.dss.llm import (
    DSSLLMCompletionQuery,
    DSSLLMCompletionResponse,
    DSSLLMStreamedCompletionChunk,
    DSSLLMStreamedCompletionFooter,
    _SSEClient,
)
from requests import Response


class AnswersCompletionQuery(DSSLLMCompletionQuery):
    def __init__(self, agent_id: str, project_key, webapp_id, agent_name):
        """
        Create a completion query that is sent to an Answers webapp backend.

        :param agent_id: Identifier of the agent; used to find this agent's entry in previous uploads.
        :param project_key: Key handed to ``client.get_project`` when the query is executed.
        :param webapp_id: Identifier handed to ``project.get_webapp`` when the query is executed.
        :param agent_name: Name of the agent, used in user-facing error messages.
        """
        # The parent is initialised without an LLM handle.
        super().__init__(llm=None)
        # Settings start empty and are exposed through the ``settings`` property.
        self._settings = {}  # type: ignore
        self.agent_id, self.agent_name = agent_id, agent_name
        self.project_key, self.webapp_id = project_key, webapp_id
        # Read in the execute methods for the ``disable_ssl_verification`` flag.
        self.webapp_config = dataiku_api.webapp_config

    @property
    def settings(self):
        """
        :return: The completion query settings.
        :rtype: dict
        """
        return self._settings

    def __replace_path_in_history(
        self,
        history_item: LlmHistory,
        portal_media_summary: MediaSummary,
        agents_uploads: Dict[str, Dict[str, MediaSummary]],
        path_key: str,
    ) -> None:
        """
        Replace, in place, one portal media path found in a history item by the path known for this agent.

        The agent-side summary is looked up in ``agents_uploads`` by the portal summary's
        ``original_file_name`` and then by ``self.agent_id``. When both the portal summary and the
        agent summary hold a non-empty value under ``path_key``, every occurrence of the portal value
        in ``history_item["input"]`` and ``history_item["output"]`` is replaced by the agent value.

        :param history_item: History entry whose ``input`` and ``output`` strings are rewritten.
        :param portal_media_summary: Media summary as known by the portal.
        :param agents_uploads: Uploaded summaries indexed by original file name, then by agent id.
        :param path_key: Key of the path to translate (e.g. ``"file_path"`` or ``"metadata_path"``).
        """
        original_file_name = portal_media_summary.get("original_file_name")
        if not original_file_name:
            # Without a file name there is nothing to look up; skip instead of raising KeyError.
            logger.debug(f"Media summary has no original file name, skipping '{path_key}' replacement.")
            return

        answers_media_summary: MediaSummary = (agents_uploads.get(original_file_name) or {}).get(
            self.agent_id
        ) or {}  # type: ignore
        if not answers_media_summary:
            logger.debug(f"No data found for id '{self.agent_id}' under file name '{original_file_name}'.")
            return

        path = answers_media_summary.get(path_key)
        if not path:
            logger.info(f"No replacement needed for '{path_key}'.")
            return

        # Check the raw value BEFORE converting to str: str(None) is the truthy string "None",
        # which would make every literal "None" in the history be replaced by the agent path.
        raw_current_path = portal_media_summary.get(path_key)
        if not raw_current_path:
            logger.error(f"Invalid current path for key '{path_key}'.")
            return
        current_path = str(raw_current_path)

        logger.debug(f"Found media summary {path_key} in history replacing: {current_path} by {path}")
        history_item["input"] = history_item["input"].replace(current_path, str(path))
        history_item["output"] = history_item["output"].replace(current_path, str(path))

    def process_history(
        self, chat_history: List[LlmHistory], agents_uploads: Optional[Dict[str, Dict[str, MediaSummary]]] = None
    ):
        processed_history = []
        current_timestamp = datetime.now().timestamp()
        portal_media_summaries = self.settings.get("media_summaries", [])
        # TODO: handle generated media, timestamp in history
        # for now using current timestamp
        for history_item in chat_history:
            if portal_media_summaries and agents_uploads:
                for portal_media_summary in portal_media_summaries:
                    logger.debug(f"Checking media summary in history: {portal_media_summary.get('file_path')}")
                    for key in ["file_path", "metadata_path"]:
                        self.__replace_path_in_history(
                            history_item,
                            portal_media_summary,
                            agents_uploads,
                            key,
                        )
            processed_history.append(
                {"query": history_item["input"], "answer": history_item["output"], "timestamp": current_timestamp}
            )
        return processed_history

    def build_query_context(self, agents_uploads: Optional[Dict[str, Dict[str, MediaSummary]]] = None):
        chat_history = self.settings.get("chat_history", [])
        current_timestamp = datetime.now().timestamp()
        processed_history = self.process_history(chat_history, agents_uploads)

        context = {
            "applicationType": "webapp",
            "applicationId": "portal-" + self.settings.get("app_id", ""),
            "botId": "portal",
            "device": self.settings.get("user_agent") or "",
            # "team": "todo",
            "timestamp": current_timestamp,
            "history": processed_history,
        }
        return context

    def build_query_user_preferences(self) -> Dict:
        """Build the user preferences by getting it from the user profile settings or return default one."""
        default_user_preferences = {
            "language": {"value": "en", "description": "User's language"},
        }
        return self.settings.get("user_profile") or default_user_preferences

    def build_query(
        self,
        uploads: Optional[List[MediaSummary]] = None,
        previous_uploads: Optional[Dict[str, Dict[str, MediaSummary]]] = None,
    ):
        query = {
            "user": self.settings.get("user", "undefined_user"),
            "query": self.cq["messages"][0]["content"],
            "context": self.build_query_context(previous_uploads),
            "conversationId": "",
            "chatSettings": {"createConversation": False, "withTitle": False, "requestedResponseFormat": "markdown"},
            "userPreferences": self.build_query_user_preferences(),
        }
        if uploads:
            query["files"] = [
                {
                    "name": upload.get("original_file_name"),
                    "path": upload.get("file_path"),
                    "thumbnail": upload.get("preview"),
                    "chainType": upload.get("chain_type"),
                    "jsonFilePath": upload.get("metadata_path"),
                    "format": "TODO",
                }
                for upload in uploads
            ]
        return query

    def get_files(self, media_summaries: List[MediaSummary]):
        """
        Load the bytes of the given media and shape them as multipart ``files[]`` entries.

        Summaries lacking a ``file_path`` or an ``original_file_name`` are skipped, as are files
        for which no bytes could be read. Any exception stops the collection: it is logged and
        the entries gathered so far are returned.

        :param media_summaries: Media summaries of the files to send.
        :return: A list of ``("files[]", (file_name, file_object, mime_type))`` tuples.
        :rtype: list
        """
        files_list = []
        try:
            logger.debug("Retrieving files to send to Answers")
            for media_summary in media_summaries:
                file_path = media_summary.get("file_path")
                file_name = media_summary.get("original_file_name", "")
                if not file_path or not file_name:
                    # The original ``pass`` fell through and tried to read a file with no path.
                    logger.error(f"Skipping media summary without file path or file name: '{file_name}'.")
                    continue
                file_bytes = get_image_bytes(file_path)
                if not file_bytes:
                    continue
                file_like_object = BytesIO(file_bytes)  # type: ignore
                mime_type = get_file_mime_type(file_name)
                if mime_type is None:
                    # Unknown type: fall back to the generic binary MIME type.
                    mime_type = "application/octet-stream"
                files_list.append(("files[]", (file_name, file_like_object, mime_type)))

        except Exception as e:
            logger.exception(f"Error while getting files: {e}")

        logger.debug(f"Files to send to Answers: {files_list}")
        return files_list

    def handle_upload_response(self, response_json: Dict):
        """
        Extract the uploaded media summaries from the parsed upload response.

        :param response_json: Parsed JSON body returned by the upload endpoint.
        :return: ``data.media_summaries`` (``[]`` when absent) if the response status is ``"ok"``;
            ``None`` when the status differs or when reading the response raised.
        """
        try:
            if response_json.get("status") != "ok":
                logger.error(f"Error while uploading files to Answers: {response_json.get('message')}")
                return None
            logger.debug("Successfully uploaded files to Answers")
            payload = response_json.get("data", {})
            return payload.get("media_summaries", [])
        except Exception as e:
            # Best effort: a malformed response is logged and reported as "no uploads".
            logger.exception(f"Error while handling upload response: {e}")
        return None

    def execute_upload(self):
        """
        Upload to the Answers backend the media files not yet uploaded for this agent.

        The files to consider are the ``media_summaries`` of the settings; a file counts as
        already uploaded when ``agents_files_uploads`` has an entry for its original file name
        that contains ``self.agent_id``.

        :return: The media summaries extracted from the upload response, or ``None`` when there
            was nothing to upload (or the response carried no usable summaries).
        :raises Exception: When the upload request is not OK or returns an empty body.
        """
        api_client = dataiku.api_client()
        if self.webapp_config.get("disable_ssl_verification"):
            api_client._session.verify = False
        backend = api_client.get_project(self.project_key).get_webapp(self.webapp_id).get_backend_client()

        media_summaries = self.settings.get("media_summaries") or None
        already_uploaded = self.settings.get("agents_files_uploads") or {}
        if not media_summaries:
            return None

        # First handle file uploads before answering the question.
        logger.debug(f"Agents files uploads: {mask_keys(already_uploaded, KEYS_TO_REMOVE_FROM_LOGS)}")
        # Only keep the files that were not previously uploaded for this agent.
        pending = []
        for summary in media_summaries:
            file_name = summary.get("original_file_name")
            if file_name not in already_uploaded or self.agent_id not in already_uploaded[file_name]:
                pending.append(summary)
        logger.debug(f"New files to upload: {[summary.get('original_file_name') for summary in pending]}")

        files = self.get_files(pending)
        if not files:
            return None

        ret = backend.session.post(
            backend.url_for_path("/api/file/upload"),
            files=files,
        )
        if not (ret.ok and ret.text):
            # TODO should we return or proceed here by sending extracted info?
            raise Exception("Failed to upload file")

        upload_response = json.loads(ret.text)
        uploaded_summaries = self.handle_upload_response(upload_response)
        logger.debug(f"Successfully uploaded to Answer: {mask_keys(upload_response, KEYS_TO_REMOVE_FROM_LOGS)}")
        return uploaded_summaries

    def execute(self):
        """
        Run the completion query against the Answers webapp and retrieve its response.

        New files are uploaded first (see :meth:`execute_upload`), then the query built by
        :meth:`build_query` is posted to the backend's ``/api/ask`` endpoint. This method does
        not raise on failure: upload errors, an empty response body and any other exception are
        turned into a response whose raw text is a JSON object with an ``answer`` message.

        :returns: The Answers response.
        :rtype: :class:`AnswersCompletionResponse`
        """
        try:
            logger.debug(
                f"AnswersCompletionQuery.execute with query: {mask_keys(self.cq, KEYS_TO_REMOVE_FROM_LOGS)}"
            )
            client = dataiku.api_client()
            if self.webapp_config.get("disable_ssl_verification"):
                logger.info("SSL verification is disabled")
                client._session.verify = False
            project = client.get_project(self.project_key)
            webapp = project.get_webapp(self.webapp_id)
            backend = webapp.get_backend_client()
            uploads: Optional[List[MediaSummary]] = None
            agents_files_uploads = self.settings.get("agents_files_uploads") or {}
            try:
                uploads = self.execute_upload()
            except Exception as e:
                logger.exception(f"Error while uploading files: {e}")
                return AnswersCompletionResponse(
                    json.dumps({"answer": f"Error while uploading files to Answers {self.agent_name}"}), None
                )
            query = self.build_query(uploads, agents_files_uploads)
            with backend.session.post(
                backend.url_for_path("/api/ask"),
                json.dumps(query),
                headers={"Content-Type": "application/json", "Accept": "application/json"},
            ) as ret:
                if ret.text:
                    logger.debug(
                        f"AnswersCompletionQuery.execute response: {mask_keys(ret.text, KEYS_TO_REMOVE_FROM_LOGS)}"
                    )
                    return AnswersCompletionResponse(ret.text, uploads)
                return AnswersCompletionResponse(json.dumps({"answer": "No answer found"}), None)
        except Exception:
            # ``Exception`` rather than a bare ``except`` so that KeyboardInterrupt/SystemExit propagate.
            logger.exception("Calling Answers failed")
            # Use instance attributes: the local ``webapp``/``project`` may not be bound yet when
            # the failure happened early, and the message must actually be interpolated.
            error_message = f"Error while calling Answers {self.webapp_id} in project: {self.project_key}"
            return AnswersCompletionResponse(json.dumps({"answer": error_message}), None)

    def execute_streamed(
        self,
    ) -> Generator[Union[DSSLLMStreamedCompletionChunk, DSSLLMStreamedCompletionFooter], Any, None]:
        """
        Run the completion query in streaming mode and relay the Answers response as it arrives.

        New files are uploaded first (see :meth:`execute_upload`), then the query is posted to
        ``/api/ask`` with ``Accept: text/event-stream``. ``completion-chunk`` events are yielded
        as chunks and the ``completion-end`` event as the footer (with the uploaded media
        summaries attached under ``media_summaries`` when there are any). Other events are
        ignored without being parsed.

        A failed upload, a non-OK HTTP status or an error while reading the stream is reported
        as a text chunk followed by a footer whose ``finishReason`` is ``"error"``. Errors raised
        while obtaining the backend client are not caught here.

        :returns: A generator of streamed chunks, ended by a footer.
        """
        # Mask sensitive keys before logging, as done in ``execute``.
        logger.debug(
            "AnswersCompletionQuery.execute_streamed with query: "
            f"{mask_keys(self.cq, KEYS_TO_REMOVE_FROM_LOGS)},{mask_keys(self.settings, KEYS_TO_REMOVE_FROM_LOGS)}"
        )
        client = dataiku.api_client()
        if self.webapp_config.get("disable_ssl_verification"):
            logger.info("SSL verification is disabled")
            client._session.verify = False
        project = client.get_project(self.project_key)
        webapp = project.get_webapp(self.webapp_id)
        backend = webapp.get_backend_client()
        uploads: Optional[List[MediaSummary]] = None
        agents_files_uploads = self.settings.get("agents_files_uploads") or {}
        try:
            uploads = self.execute_upload()
        except Exception as e:
            logger.exception(f"Error while uploading files: {e}")
            error_message = f"Failed to upload documents to {self.agent_name}."
            # TODO should we return or proceed here by sending extracted info?
            yield DSSLLMStreamedCompletionChunk({"text": error_message})
            yield DSSLLMStreamedCompletionFooter({"finishReason": "error", "errorMessage": error_message})
            return
        query = self.build_query(uploads, agents_files_uploads)
        try:
            with backend.session.post(
                backend.url_for_path("/api/ask"),
                json.dumps(query),
                headers={"Content-Type": "application/json", "Accept": "text/event-stream"},
                stream=True,
            ) as ret:
                response: Response = ret

                if not response.ok:
                    logger.error(
                        f"Communication with Answers {self.agent_name} failed. Status code: {response.status_code}."
                    )
                    error_message = f"Communication with Answers {self.agent_name} failed."
                    yield DSSLLMStreamedCompletionChunk({"text": error_message})
                    yield DSSLLMStreamedCompletionFooter(
                        {
                            "finishReason": "error",
                            "errorMessage": error_message,
                        }
                    )
                    return

                # Distinct name: do not shadow the API client created above.
                sse_client = _SSEClient(response.raw)
                for event in sse_client.iterevents():
                    if event.event == "completion-chunk":
                        yield DSSLLMStreamedCompletionChunk(json.loads(event.data))
                    elif event.event == "completion-end":
                        data = json.loads(event.data)
                        if data.get("error"):
                            raise Exception(f"Answers sent back an error: '{data}'")
                        if data and uploads:
                            data["media_summaries"] = uploads
                        yield DSSLLMStreamedCompletionFooter(data)
                    else:
                        # Ignored events are not parsed, so a non-JSON payload cannot abort the stream.
                        logger.debug(f"Ignore event {event.event} with data {event.data}.")
        except Exception as e:
            logger.exception(f"Failed to get answers completion: {e}")
            error_message = "Failed fetching complete response from answers."
            yield DSSLLMStreamedCompletionChunk({"text": error_message})
            yield DSSLLMStreamedCompletionFooter({"finishReason": "error", "errorMessage": error_message})


class AnswersCompletionResponse(DSSLLMCompletionResponse):
    """
    A handle to interact with an Answers completion response.

    .. important::
        Do not create this class directly, use :meth:`dataikuapi.dss.llm.DSSLLMCompletionQuery.execute` instead.
    """

    def __init__(self, raw_resp, uploads):
        self._raw = raw_resp
        self._json = None
        self._uploads = uploads

    @property
    def json(self):
        """
        :return: Answers response parsed as a JSON object
        """
        if not self.success:
            if self._json:
                error_message = self._json.get("data", {}).get("errorMessage", "An unknown error occurred")
            else:
                error_message = "Failed to parse response to Json"
            raise Exception(error_message)

        if self._json is None:
            self._json = json.loads(self._raw)
        if self._uploads and self._json.get("data") and not self._json["data"].get("media_summaries"):
            self._json["data"]["media_summaries"] = self._uploads
        return self._json

    @property
    def success(self):
        """
        :return: The outcome of the completion query.
        :rtype: bool
        """
        try:
            resp_json = json.loads(self._raw)
            self._json = resp_json
            return "ok" in resp_json.get("status")
        except:
            return False

    @property
    def text(self):
        """
        :return: The raw text of the Answers response.
        :rtype: Union[str, None]
        """
        return self._raw

    @property
    def tool_calls(self):
        """
        :return: The tool calls of the Answers response.
        :rtype: Union[list, None]
        """
        return None

    @property
    def log_probs(self):
        """
        :return: The log probs of the Answers response.
        :rtype: Union[list, None]
        -- classification
        """
        return None
