import json
import re
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional

from common.backend.constants import DEFAULT_DECISIONS_GENERATION_ERROR, PROMPT_DATE_FORMAT
from common.backend.models.base import ConversationParams, LlmHistory, MediaSummary
from common.backend.utils.llm_utils import (
    add_history_to_completion,
    append_summaries_to_completion_msg,
    get_llm_completion,
    handle_response_trace,
    parse_error_messages,
)
from common.llm_assist.fallback import get_fallback_completion, is_fallback_enabled
from common.llm_assist.logging import logger
from dataiku.langchain.dku_llm import DKULLM
from dataikuapi.dss.llm import DSSLLMCompletionQuery, DSSLLMCompletionQueryMultipartMessage, DSSLLMCompletionResponse


# TODO: refactor to inherit from GenericDecisionJSONChain
class GenericDBAgent(ABC):
    # Abstract base class for database (SQL) agents.
    # Child classes implement the abstract prompt properties and the other abstract members declared below.
    @property
    def formatted_errors(self):
        if len(self.previous_sql_errors) > 0:
            errors = ""
            for idx, error_response in enumerate(self.previous_sql_errors):
                errors += f"""Attempt {idx}:
                Response:
                {error_response["response"]}
                Query:
                {error_response["query"]}
                Error returned
                {error_response["error"]}
                ---------------------------------
                """
            return errors
        logger.debug("errors: NO ERRORS!")
        return ""

    @property
    @abstractmethod
    def graph_system_prompt(self):
        pass

    @property
    @abstractmethod
    def graph_user_prompt(self):
        pass

    @abstractmethod
    def update_graph(self, selected_graph: Optional[Dict[str, Any]], justification: str):
        pass

    @property
    @abstractmethod
    def query_system_prompt(self):
        pass

    @property
    @abstractmethod
    def query_user_prompt(self):
        pass

    @property
    @abstractmethod
    def chosen_tables_and_columns(self) -> str:
        pass

    @property
    @abstractmethod
    def graph_justification(self):
        pass

    @property
    @abstractmethod
    def error_system_prompt(self):
        pass

    @property
    @abstractmethod
    def error_user_prompt(self):
        pass

    @property
    @abstractmethod
    def previous_sql_errors(self):
        pass

    def get_retrieval_graph(
        self,
        decision_llm,
        chat_history,
        user_input,
        conversation_params: ConversationParams,
        first_attempt: bool = True,
    ):
        """Ask the decision LLM which tables, columns and joins are needed for ``user_input``.

        The LLM answer must contain a JSON object; its ``tables_and_columns``,
        ``suggested_joins`` and ``justification`` fields are handed to
        ``update_graph``.

        Args:
            decision_llm: LLM handle given to ``get_llm_completion`` and ``is_fallback_enabled``.
            chat_history: Earlier turns, added through ``add_history_to_completion``.
            user_input: Current user message, sent as the last ``user`` message.
            conversation_params: Read for ``media_summaries`` and
                ``previous_media_summaries``; when any are present they are attached
                through a multipart ``user`` message.
            first_attempt: ``False`` on the retry, in which case the completion is
                swapped for the one returned by ``get_fallback_completion``.

        Returns:
            The value returned by ``update_graph``.

        Raises:
            Exception: when the answer holds no decodable JSON object (no retry), or
                when any other step fails and no fallback retry is available.
        """
        error_prefix = f"{DEFAULT_DECISIONS_GENERATION_ERROR} SQL graph: "
        media_summaries: List[MediaSummary] = conversation_params.get("media_summaries") or []
        previous_media_summaries: List[MediaSummary] = conversation_params.get("previous_media_summaries") or []
        summaries = media_summaries + previous_media_summaries

        try:
            graph_user_prompt = (
                f"Today's date and time: {datetime.now().strftime(PROMPT_DATE_FORMAT)} \n " + self.graph_user_prompt
            )
            logger.debug(f"""DB Metadata identification FINAL PROMPT:
graph_system_prompt: {self.graph_system_prompt}
user_prompt: {graph_user_prompt}
            """)
            completion = get_llm_completion(decision_llm)
            if not first_attempt:
                completion = get_fallback_completion(completion)
            completion.with_message(self.graph_system_prompt, role="system")
            completion = add_history_to_completion(completion, chat_history)
            # Send the dated prompt that was just logged as the final prompt
            # (previously the undated property value was sent instead).
            completion.with_message(graph_user_prompt, role="user")
            if summaries:
                msg: DSSLLMCompletionQueryMultipartMessage = completion.new_multipart_message(role="user")
                append_summaries_to_completion_msg(summaries, msg)
            completion.with_message(user_input, role="user")
            resp: DSSLLMCompletionResponse = completion.execute()
            handle_response_trace(resp)
            if not resp.text and resp.errorMessage:
                # Raise the bare LLM error; the handler below adds the prefix once.
                raise Exception(str(resp.errorMessage))
            text = str(resp.text)
            logger.debug(f"Generated response for DB queries information: {text}")
            match = re.search(r"\{.*\}", text, re.DOTALL)
            if not match:
                raise json.JSONDecodeError(f"No JSON object found in the response {text}", text, 0)
            response_json = json.loads(match.group(0))
            # Missing keys simply come through as None.
            selected_graph = {
                "tables_and_columns": response_json.get("tables_and_columns"),
                "suggested_joins": response_json.get("suggested_joins"),
            }
            return self.update_graph(selected_graph, response_json.get("justification"))
        except json.JSONDecodeError as e:
            # Malformed answers are not retried with the fallback LLM.
            error_message = f"{error_prefix}Error decoding JSON during retrieval graph processing: {e.msg}"
            logger.exception(error_message)
            raise Exception(error_message) from e
        except Exception as e:
            error_message = error_prefix + (parse_error_messages(str(e)) or str(e))
            logger.exception(error_message)
            if first_attempt and is_fallback_enabled(decision_llm):
                return self.get_retrieval_graph(decision_llm, chat_history, user_input, conversation_params, False)
            raise Exception(error_message) from e

    def get_retrieval_query(
        self,
        decision_llm: DKULLM,
        conversation_params: ConversationParams,
        chat_history: List[LlmHistory],
        user_input: str,
        failed_response: dict = {},
        completion: Optional[DSSLLMCompletionQuery] = None,
        first_attempt: bool = True,
    ):
        error_message = f"{DEFAULT_DECISIONS_GENERATION_ERROR} SQL query: "
        justification = ""
        if self.chosen_tables_and_columns == "":
            return None, self.graph_justification

        text = ""
        try:
            query_user_prompt = (
                f"Today's date and time: {datetime.now().strftime(PROMPT_DATE_FORMAT)} \n " + self.query_user_prompt
            )
            if not failed_response:
                # We only log this once not on every failed attempt
                logger.debug(f""" DB query FINAL PROMPT:
graph_system_prompt: {self.query_system_prompt}
user_prompt: {query_user_prompt}
            """)
            completion = get_llm_completion(decision_llm)
            if not first_attempt:
                completion = get_fallback_completion(completion)
            completion.with_message(self.query_system_prompt, role="system")
            completion = add_history_to_completion(completion, chat_history)
            completion.with_message(query_user_prompt, role="user")
            completion.with_message(user_input, role="user")
            resp: DSSLLMCompletionResponse
            if failed_response:
                logger.debug("Correcting query prompt ...")
                resp = self.fix_retrieval_query(decision_llm=decision_llm, user_input=user_input, chat_history=chat_history)
            else:
                logger.debug("Creating initial query prompt ...")
                resp = completion.execute()
            handle_response_trace(resp)
            text = str(resp.text)
            if not text and resp.errorMessage:
                error_message += str(resp.errorMessage)
                raise Exception(error_message)
            logger.debug(f"Generated response for DB query: {text}")
            match = re.search(r"\{.*\}", text, re.DOTALL)
            if match:
                json_str = match.group(0)
                response_json = json.loads(json_str)

                with_ = response_json.get("with")
                select_list = response_json.get("selectList")
                from_ = response_json.get("from")
                join = response_json.get("join")
                where_ = response_json.get("where")
                grp_by = response_json.get("groupBy")
                having = response_json.get("having")
                order_by = response_json.get("orderBy")
                limit = response_json.get("limit")
                justification = response_json.get("justification")

                logger.debug(
                    f"with: {with_}, select list: {select_list}, from: {from_}, join: {join}, where: {where_}, group by: {grp_by}, having: {having}, order_by: {order_by}, limit: {limit}, Justification: {justification}"
                )

                if all(not q for q in [with_, select_list, from_, join, where_, grp_by, having, order_by, limit]):
                    return None, justification

                query_dict = {
                    "with": with_,
                    "selectList": select_list,
                    "from": from_,
                    "join": join,
                    "where": where_,
                    "groupBy": grp_by,
                    "having": having,
                    "orderBy": order_by,
                    "limit": limit,
                }
                return query_dict, justification
            else:
                error_message += f"No JSON object found in the response"
                logger.error(error_message)
                raise json.JSONDecodeError(error_message, text, 0)
        except json.JSONDecodeError as e:
            error_message += f"Error decoding JSON during retrieval graph processing"
            logger.exception(error_message)
            raise Exception(error_message)
        except Exception as e:
            error_message += parse_error_messages(str(e)) or str(e)
            logger.exception(error_message)
            fallback_enabled = is_fallback_enabled(decision_llm)
            if first_attempt and fallback_enabled:
                return self.get_retrieval_query(decision_llm, conversation_params, chat_history, user_input, failed_response, completion, False)
            raise Exception(error_message)

    def fix_retrieval_query(self, decision_llm: DKULLM, user_input: str, chat_history: List[LlmHistory]) -> DSSLLMCompletionResponse:
        """Ask the LLM for a new answer after earlier attempts failed.

        Every entry of ``previous_sql_errors`` is rendered as an ``Attempt <index>:``
        section, the sections are injected into ``error_user_prompt`` through its
        ``attempts_history`` field, and a completion made of the error system
        prompt, the chat history, the filled-in error prompt and ``user_input`` is
        executed.

        Args:
            decision_llm: LLM handle given to ``get_llm_completion``.
            user_input: Current user message, sent as the last ``user`` message.
            chat_history: Earlier turns, added through ``add_history_to_completion``.

        Returns:
            The response of ``completion.execute()``, unmodified.
        """
        logger.debug(f"DB query: attempting to fix previous errors  {self.previous_sql_errors}")
        # The indentation inside the templates below is part of the prompt text.
        attempt_sections = "".join(
            f"""Attempt {attempt_no}:
            Response:
            {failure["response"]}
            Query:
            {failure["query"]}
            Error returned
            {failure["error"]}
            ---------------------------------
            """
            for attempt_no, failure in enumerate(self.previous_sql_errors)
        )
        attempts_history = f"""
        This is the attempt history:
        {attempt_sections}
        """
        correction_prompt = self.error_user_prompt.format(attempts_history=attempts_history)
        logger.debug(f"""
            Error correction prompt:
            error_system_prompt: {self.error_system_prompt}
            error_user_prompt: {correction_prompt}
        """)
        correction_query: DSSLLMCompletionQuery = get_llm_completion(decision_llm)
        correction_query.with_message(self.error_system_prompt, role="system")
        correction_query = add_history_to_completion(correction_query, chat_history)
        correction_query.with_message(correction_prompt, role="user")
        correction_query.with_message(user_input, role="user")

        return correction_query.execute()
