import json
import re
from abc import ABC, abstractmethod
from datetime import datetime
from typing import List, Tuple

from common.backend.constants import DEFAULT_DECISIONS_GENERATION_ERROR, PROMPT_DATE_FORMAT
from common.backend.models.base import ConversationParams, LlmHistory, MediaSummary
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.llm_utils import (
    add_history_to_completion,
    append_summaries_to_completion_msg,
    get_llm_completion,
    handle_response_trace,
    parse_error_messages,
)
from common.llm_assist.fallback import get_fallback_completion, is_fallback_enabled
from common.llm_assist.logging import logger
from dataikuapi.dss.llm import DSSLLMCompletionQueryMultipartMessage, DSSLLMCompletionResponse


# TODO: refactor this class to inherit from GenericDecisionJSONChain
class GenericKBAgent(ABC):
    """Abstract base class for Knowledge Bank agents.

    Subclasses implement the ``system_prompt`` and ``user_prompt`` properties.
    :meth:`get_retrieval_query` sends them to a decision LLM and extracts a
    ``query`` / ``justification`` pair from the JSON object in the LLM's answer.
    """

    @property
    @abstractmethod
    def system_prompt(self):
        """System prompt sent to the decision LLM. Must be implemented by subclasses."""
        return self._system_prompt

    @property
    @abstractmethod
    def user_prompt(self):
        """User prompt template sent to the decision LLM. Must be implemented by subclasses.

        When the webapp option ``include_user_profile_in_KB_prompt`` is ``True``, the
        date-prefixed template goes through ``str.format(user_profile=...)``. In that case
        any literal braces in the template must be escaped as ``{{`` / ``}}``.
        """
        return self._user_prompt

    def get_retrieval_query(self, decision_llm, chat_history: List[LlmHistory], user_input: str, conversation_params: ConversationParams, first_attempt: bool = True) -> Tuple[str, str]:
        """Ask the decision LLM for a Knowledge Bank retrieval query.

        The completion is built in this order:
        1. the system prompt;
        2. the chat history;
        3. the date-prefixed user prompt;
        4. optional media summaries;
        5. ``user_input``.

        The first-``{``-to-last-``}`` span of the answer is parsed as JSON.

        Args:
            decision_llm: LLM handle passed to ``get_llm_completion`` and ``is_fallback_enabled``.
            chat_history: Previous exchanges, added before the new user messages.
            user_input: The latest user message.
            conversation_params: Read for ``user_profile``, ``media_summaries`` and
                ``previous_media_summaries``.
            first_attempt: ``False`` when this call is the retry on the fallback LLM.

        Returns:
            ``(query, justification)``. ``query`` defaults to ``""`` when the key is missing.
            ``justification`` is the raw value under that key (``None`` if absent).

        Raises:
            Exception: If the answer contains no parseable JSON object. Also raised if the
                LLM call fails and no fallback retry is available (already on the fallback,
                or fallback disabled).
        """
        error_message = f"{DEFAULT_DECISIONS_GENERATION_ERROR} KB decision: "
        user_profile = conversation_params.get("user_profile", {}) or {}
        # Strict identity check: only a real boolean True enables profile injection.
        include_user_profile = (
            dataiku_api.webapp_config.get("include_user_profile_in_KB_prompt", False)
            is True
        )
        logger.debug(f" include_user_profile: {include_user_profile}")
        media_summaries: List[MediaSummary] = conversation_params.get("media_summaries") or []
        previous_media_summaries: List[MediaSummary] = conversation_params.get("previous_media_summaries") or []
        summaries = media_summaries + previous_media_summaries

        try:
            user_prompt = f"Today's date and time: {datetime.now().strftime(PROMPT_DATE_FORMAT)} \n " + self.user_prompt
            if include_user_profile:
                user_prompt = user_prompt.format(user_profile=json.dumps(user_profile))
            logger.debug(f"""Knowledge Bank retrieval query FINAL PROMPT:
system_prompt: {self.system_prompt}
user_prompt: {user_prompt}
            """)
            completion = get_llm_completion(decision_llm)
            if not first_attempt:
                completion = get_fallback_completion(completion)
            completion.with_message(self.system_prompt, role="system")
            completion = add_history_to_completion(completion, chat_history)
            completion.with_message(user_prompt, role="user")
            if summaries:
                # NOTE(review): the multipart message is not added to the query here
                # (no `.add()` call); presumably append_summaries_to_completion_msg does it — confirm.
                msg: DSSLLMCompletionQueryMultipartMessage = completion.new_multipart_message(role="user")
                append_summaries_to_completion_msg(summaries, msg)
            completion.with_message(user_input, role="user")
            resp: DSSLLMCompletionResponse = completion.execute()
            handle_response_trace(resp)
            # resp.text can be None on failure. str(None) would give the truthy "None",
            # which would bypass the errorMessage check below and hide the real LLM error.
            text = str(resp.text or "")
            logger.debug(f"Generated response for knowledge bank: {text}")
            if not text and resp.errorMessage:
                # Raise only the LLM's own message. The generic handler below appends it
                # to error_message exactly once (and may retry on the fallback LLM).
                raise Exception(str(resp.errorMessage))
            match = re.search(r"\{.*\}", text, re.DOTALL)
            if not match:
                error_message += f"No JSON object found in the response {text}"
                raise json.JSONDecodeError(error_message, text, 0)
            response_json = json.loads(match.group(0))
            query = response_json.get("query", "")
            justification = response_json.get("justification")
            logger.debug(f"Query: {query}, Justification: {justification}")
            return query, justification
        except json.JSONDecodeError as e:
            # Malformed or missing JSON is reported directly; no fallback retry on this path.
            error_message += "Error decoding JSON during retrieval graph processing"
            logger.exception(error_message)
            raise Exception(error_message) from e
        except Exception as e:
            error_message += parse_error_messages(str(e)) or str(e)
            logger.exception(error_message)
            # Retry once on the fallback LLM, if one is configured.
            fallback_enabled = is_fallback_enabled(decision_llm)
            if first_attempt and fallback_enabled:
                return self.get_retrieval_query(
                    decision_llm, chat_history, user_input, conversation_params, first_attempt=False
                )
            raise Exception(error_message) from e
