import json
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union

from answers.backend.models.base import METADATA_REPLACEMENTS, EmbeddingRecipeType
from answers.backend.services.sources.sources_builder_specific_answers import generate_sources_from_source_documents
from answers.backend.utils.config_utils import get_retriever_info
from answers.backend.utils.knowledge_banks_params import KnowledgeBanksParams
from answers.backend.utils.knowledge_filters import process_filters_for_db
from answers.backend.utils.langchain_document_utils import (
    from_documents_to_multipart_context,
    from_documents_to_structured_context,
)
from answers.solutions.chains.kb.description_based_kb_agent import DescriptionBasedKBAgent
from answers.solutions.chains.kb.generic_kb_agent import GenericKBAgent
from answers.solutions.chains.kb.simple_kb_agent import SimpleKBAgent
from answers.solutions.knowledge_bank import (
    get_knowledge_bank_info,
    get_knowledge_bank_retriever,
)
from answers.solutions.prompts.citations import CITATIONS_PROMPT
from common.backend.constants import DEFAULT_DECISIONS_GENERATION_ERROR, PROMPT_DATE_FORMAT
from common.backend.models.base import (
    ConversationParams,
    LLMContext,
    LlmHistory,
    LLMStep,
    MediaSummary,
    RetrievalSummaryJson,
)
from common.backend.utils.config_utils import resolve_webapp_param
from common.backend.utils.context_utils import LLMStepName
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.llm_utils import (
    get_llm_capabilities,
    get_llm_completion,
    handle_prompt_media_explanation,
    handle_response_trace,
    parse_error_messages,
)
from common.backend.utils.prompt_utils import append_user_profile_to_prompt
from common.llm_assist.fallback import get_fallback_completion, is_fallback_enabled
from common.llm_assist.logging import logger
from common.solutions.chains.generic_answers_chain import GenericAnswersChain
from dataiku.core.knowledge_bank import MultipartContext
from dataiku.langchain.dku_llm import DKULLM
from dataikuapi.dss.llm import DSSLLMCompletionResponse
from langchain.schema.document import Document


class KBRetrievalChain(GenericAnswersChain):
    """Answers chain that grounds the LLM answer in a knowledge bank retrieval.

    On construction it resolves the metadata filters applied to the knowledge bank
    (user-provided filters, LLM-computed "auto" filters, or none) and selects the agent
    that turns the conversation into a retrieval query. The other methods build the
    prompts, run the retrieval and shape the final JSON payload.
    """

    def __init__(
        self,
        llm: DKULLM,
        decision_llm: DKULLM,
        query: str,
        filters: Optional[Dict[str, List[Any]]],
        chat_has_media: bool = False,
    ):
        """Create the chain.

        Args:
            llm (DKULLM): Main LLM, exposed through the ``llm`` property.
            decision_llm (DKULLM): LLM used for auxiliary decisions (auto filtering and
                retrieval query generation).
            query (str): User query, used to compute auto filters when they are enabled.
            filters (Optional[Dict[str, List[Any]]]): User-selected metadata filters
                (metadata name -> accepted values), or None.
            chat_has_media (bool): Whether the conversation contains media; forwarded to
                the retrieval query agent and used when building the system prompt.
        """
        super().__init__()
        self.__llm = llm
        self.__decision_llm = decision_llm
        self.__kbs_params = KnowledgeBanksParams()
        self.__act_like_prompt = ""
        self.__system_prompt = ""
        self.__prompt_with_media_explanation = chat_has_media
        self.__knowledge_banks_to_use: List[str] = []
        # Shallow copy so later re-wrapping does not alias the caller's dict.
        self.__input_filters = filters.copy() if filters else {}
        self.__filters = self.__create_kb_filters(filters, query)
        self.__retrieval_query_agent: GenericKBAgent
        if self.webapp_config.get("enable_smart_usage_of_kb", True):
            self.__retrieval_query_agent = DescriptionBasedKBAgent(chat_has_media=chat_has_media)
        else:
            self.__retrieval_query_agent = SimpleKBAgent(chat_has_media=chat_has_media)

    @property
    def act_like_prompt(self) -> str:
        """Role prompt set by ``load_role_and_guidelines_prompts``."""
        return self.__act_like_prompt

    @property
    def system_prompt(self) -> str:
        """Guidelines prompt set by ``load_role_and_guidelines_prompts``."""
        return self.__system_prompt

    @property
    def input_filters(self) -> Any:
        """Filters as reported back to the caller.

        Wrapped as ``{knowledge_bank_id: ...}`` at construction time and unwrapped again
        by ``get_as_json``.
        """
        return self.__input_filters

    @input_filters.setter
    def input_filters(self, filters: Dict[str, List[Any]]) -> None:
        self.__input_filters = filters

    @property
    def filters(self) -> Any:
        """Database-ready filters keyed by knowledge bank id, used by the retriever."""
        return self.__filters

    @property
    def knowledge_banks_to_use(self) -> List[str]:
        """Knowledge bank ids selected for the current query (empty when none)."""
        return self.__knowledge_banks_to_use

    @property
    def retrieval_query_agent(self) -> GenericKBAgent:
        return self.__retrieval_query_agent

    @property
    def llm(self) -> DKULLM:
        return self.__llm

    @property
    def decision_llm(self) -> DKULLM:
        return self.__decision_llm

    @property
    def chain_purpose(self) -> str:
        return LLMStepName.KB_ANSWER.value

    def __verify_filters(self, filters: Optional[Dict[str, Any]]) -> Optional[Dict[str, List[Any]]]:
        """Keep only the filters and values allowed by the knowledge bank filters configuration.

        Args:
            filters: Candidate filters (metadata name -> value or list of values), either
                provided by the user or extracted from the decision LLM answer.

        Returns:
            None when there is nothing to verify or no filters configuration. Otherwise a
            dict restricted to the configured metadata names, each mapped to the subset of
            its values listed in ``filter_options`` (possibly an empty list).
        """
        filters_config = self.__kbs_params.filters_config
        if not filters or not filters_config:
            return None
        # Metadata names may be exposed under a replacement name (METADATA_REPLACEMENTS).
        valid_metadata = {
            METADATA_REPLACEMENTS.get(metadata) or metadata for metadata in filters_config["filter_metadata"]
        }
        verified_filters: Dict[str, List[Any]] = {}
        for metadata, values in filters.items():
            if metadata not in valid_metadata:
                continue
            allowed_values = filters_config["filter_options"].get(metadata)
            if allowed_values is None:
                logger.warning(f"No filter options configured for metadata '{metadata}', ignoring its values")
                allowed_values = []
            # The LLM may answer with a bare value instead of a list; iterating a string
            # would otherwise test its individual characters.
            if not isinstance(values, (list, tuple, set)):
                values = [values]
            verified_filters[metadata] = [value for value in values if value in allowed_values]
        return verified_filters

    def __create_kb_filters(self, filters: Optional[Dict[str, List[Any]]], query: str) -> Dict[str, Any]:
        """Build the per-knowledge-bank filters handed to the retriever.

        Three cases:
          1. a vector DB type is known and the user supplied filters: they are verified
             against the configuration and converted with ``process_filters_for_db``;
          2. no user filters, a filters configuration exists and auto filtering is
             enabled: filters are computed from ``query`` by the decision LLM;
          3. otherwise: no filtering.

        Side effect: ``self.__input_filters`` is rewritten to be keyed by the knowledge
        bank id.

        Returns:
            Dict[str, Any]: ``{knowledge_bank_id: db_filters}``.
        """
        vector_db_type = self.__kbs_params.knowledge_bank_vector_db_types[0]
        knowledge_bank_id = self.__kbs_params.knowledge_bank_ids[0]
        advanced_filtering_possible = bool(dataiku_api.webapp_config.get("knowledge_sources_filters", []))
        enable_auto_filtering = resolve_webapp_param(
            param_name="knowledge_enable_auto_filtering",
            default_value=False,
            advanced_mode_enabled=advanced_filtering_possible,
        )
        logger.debug(f'filters prior to the verification: {filters}')
        if vector_db_type and filters:
            verified_filters = process_filters_for_db(
                self.__verify_filters(filters=filters), vector_db_type
            )
            logger.debug(
                f"vector_db_type:{vector_db_type}   / filters: {filters} with verified_filters: {verified_filters}"
            )
            self.__input_filters = {knowledge_bank_id: self.input_filters}
            return {knowledge_bank_id: verified_filters}
        if not filters and self.__kbs_params.filters_config and enable_auto_filtering:
            computed_filters = self.__auto_filters(query, self.__kbs_params.filters_config["filter_options"])
            logger.debug(f"computed_filters: {computed_filters}")
            self.__input_filters = {knowledge_bank_id: self.input_filters}
            return {
                knowledge_bank_id: process_filters_for_db(
                    filters=computed_filters, vector_db_type=vector_db_type
                )
            }
        self.__input_filters = {knowledge_bank_id: []}
        return {knowledge_bank_id: {}}

    def __extract_json(self, response_text: str) -> Dict[str, Any]:
        """Extract the most plausible JSON object embedded in an LLM response.

        Every ``{`` in the text is tried as the start of a JSON object with
        ``json.JSONDecoder.raw_decode``. Unlike a non-greedy regex, this copes with nested
        braces and with objects spread over several lines (pretty-printed JSON, markdown
        code fences). Among the objects that decode, the one spanning the most characters
        wins, since it is most likely to be the full answer rather than a fragment.

        Args:
            response_text (str): Raw text returned by the LLM.

        Returns:
            Dict[str, Any]: The decoded object, or an empty dict when none is found.
        """
        decoder = json.JSONDecoder()
        best_json: Dict[str, Any] = {}
        best_span = 0
        start = response_text.find("{") if response_text else -1
        while start != -1:
            try:
                candidate, span = decoder.raw_decode(response_text[start:])
            except json.JSONDecodeError:
                # No valid object opens here; try the next opening brace.
                start = response_text.find("{", start + 1)
                continue
            if isinstance(candidate, dict) and span > best_span:
                best_json, best_span = candidate, span
            # Any object opening inside this one is nested in it, hence shorter: skip past it.
            start = response_text.find("{", start + span)
        if not best_span:
            logger.debug(f"No JSON object found in the filters response: '{response_text}'")
        return best_json

    def __auto_filters(
        self, user_query: str, possible_filters: Dict[str, List[str]], first_attempt: bool = True
    ) -> Union[Dict[str, List[str]], None]:
        """Ask the decision LLM which of the configured filter values apply to the user query.

        Args:
            user_query (str): The user query.
            possible_filters (Dict[str, List[str]]): Filter name -> possible values.
            first_attempt (bool): False when retrying with the fallback completion.

        Returns:
            Union[Dict[str, List[str]], None]: The relevant filters after verification
            against the configuration, or None when nothing could be verified.

        Raises:
            Exception: When the completion fails and no fallback attempt is possible.
        """
        error_message = f"{DEFAULT_DECISIONS_GENERATION_ERROR} Auto filter: "
        logger.debug("Auto Filtering On")
        datetime_now = datetime.now().strftime(PROMPT_DATE_FORMAT)
        prompt = (
            f"Today's date and time: {datetime_now} "
            f"Examine the user query and Identify which of the following possible filter values are relevant: "
            "# POSSIBLE FILTERS: "
            f"{possible_filters}. "
            "# INSTRUCTIONS "
            "1 - Return only a JSON object of the relevant filters directly, with no additional explanations, text decorations, "
            "or markdown. Ensure the JSON object is in a format that can be parsed directly by a JSON parser without needing "
            "any preprocessing. "
            "2- If no filters apply or if you are not sure about the relevant filters, return an empty JSON object: {}. "
            "3- Remember to strictly return the relevant filters JSON object with no additional text or markdown. "
            "### EXAMPLES: "
            "- Example 1: "
            "user query: What are the revenues for Q1 2023 and Q2 in 2023? "
            'possible filters: {"file":["2023 Q2.pdf", "2023 Q1.pdf","2024 Q2.pdf"]} '
            'Expected Answer is: {"file":["2023 Q2.pdf", "2023 Q1.pdf"]}. '
            "- Example 2: "
            "user query: How many vacation days do we have in California. "
            'possible filters: {"location":["California", "New York", "Texas", "Paris"], "tags": ["source","name","url"], "date": ["2022","2023","2024"]} '
            'Expected Answer is: {"location":["California"]}. '
            "- Example 3: "
            "user query: What are the sales for the 2023 last quarter for Microsoft and Samsung? "
            'possible filters: {"company":["Apple", "Microsoft", "Samsung"], "quarter": ["Q1","Q2","Q3","Q4"], "year": ["2022","2023","2024"]} '
            'Expected Answer is: {"company":["Microsoft", "Samsung"], "quarter": ["Q4"], "year": ["2023"]}. '
            "- Example 4: "
            "user query: How many employees are in the marketing department? "
            'possible filters: {"location": ["Paris", "New York", "London"]} '
            "Expected Answer is: {}. "
        )
        # TODO: if we want to append the user profile to the prompt, we can do it here

        logger.debug(f"Auto filters prompt: {prompt}")
        try:
            completion = get_llm_completion(self.decision_llm)
            if not first_attempt:
                completion = get_fallback_completion(completion)
            completion.with_message(prompt, role="system")
            completion.with_message(user_query, role="user")
            llm_response: DSSLLMCompletionResponse = completion.execute()
            # ``text`` can be None on a failed completion; ``str(None)`` would give the
            # truthy string "None" and hide the error (and the fallback) below.
            response = "" if llm_response.text is None else str(llm_response.text)
            handle_response_trace(response)
            if not response and llm_response.errorMessage:
                # The message is appended to ``error_message`` by the handler below.
                raise Exception(llm_response.errorMessage)
            logger.debug(f"Auto filters llm response: '{response}'")

            relevant_filters = self.__extract_json(response)
            logger.debug(f"Extracted filters  {relevant_filters}")
            return self.__verify_filters(relevant_filters)
        except json.JSONDecodeError as e:
            error_message += "Error decoding JSON during auto filter processing"
            logger.exception(error_message)
            raise Exception(error_message) from e
        except Exception as e:
            error_message += parse_error_messages(str(e)) or str(e)
            logger.exception(error_message)
            # checking if we can use a fallback LLM
            fallback_enabled = is_fallback_enabled(self.decision_llm)
            if first_attempt and fallback_enabled:
                return self.__auto_filters(user_query, possible_filters, False)
            raise Exception(error_message) from e

    def __get_retriever(self) -> Any:
        """Build the retriever for the first selected knowledge bank, with its filters.

        Raises:
            IndexError: If no knowledge bank is selected.
        """
        knowledge_bank_id = self.knowledge_banks_to_use[0]
        knowledge_bank_filter = self.filters[knowledge_bank_id]
        return get_knowledge_bank_retriever(
            knowledge_bank_id,
            knowledge_bank_filter,
            self.__kbs_params.all_knowledge_bank_parameters["retrieval_parameters"]["search_type"],
        )  # type: ignore[no-any-return]

    def load_role_and_guidelines_prompts(self, params: ConversationParams):
        """Resolve the role and system prompts and store them on the chain.

        KB-specific prompts from the webapp configuration are used when a knowledge bank
        is selected, the defaults otherwise. The user profile and the media explanation
        are then appended to the system prompt.
        """
        if self.knowledge_banks_to_use:
            act_like_prompt = self.webapp_config.get("knowledge_bank_prompt", "")
            system_prompt = dataiku_api.webapp_config.get(
                "knowledge_bank_system_prompt",
                "Given the following specific context and the conversation between a user and an assistant, please give a short answer to the question at the end, If you don't know the answer, just say that you don't know, don't try to make up an answer.",
            )
        else:
            act_like_prompt, system_prompt = self.load_default_role_and_guidelines_prompts()
        user_profile = params.get("user_profile", None)
        include_full_user_profile = bool(self.webapp_config.get("include_user_profile_in_KB_prompt", False))
        system_prompt = append_user_profile_to_prompt(
            system_prompt=system_prompt, user_profile=user_profile, include_full_user_profile=include_full_user_profile
        )
        system_prompt = handle_prompt_media_explanation(
            system_prompt=system_prompt, has_media=self.__prompt_with_media_explanation
        )
        self.__act_like_prompt = act_like_prompt
        self.__system_prompt = system_prompt

    def get_computed_system_prompt(self, params: ConversationParams) -> str:
        """Assemble the final system prompt: date, role, guidelines and optional citations prompt."""
        datetime_now = datetime.now().strftime(PROMPT_DATE_FORMAT)
        enable_llm_citations = self.webapp_config.get("enable_llm_citations", False)
        return rf"""
        Today's date and time: {datetime_now}
        {self.act_like_prompt}
        
        {self.system_prompt}

        {CITATIONS_PROMPT if enable_llm_citations else ""}
        """

    def get_computing_prompt_step(self) -> LLMStep:
        if self.knowledge_banks_to_use:
            return LLMStep.COMPUTING_PROMPT_WITH_KB
        return LLMStep.COMPUTING_PROMPT_WITHOUT_RETRIEVAL

    def get_querying_step(self, params: ConversationParams) -> LLMStep:
        if self.knowledge_banks_to_use:
            return LLMStep.QUERYING_LLM_WITH_KB
        return LLMStep.QUERYING_LLM_WITHOUT_RETRIEVAL

    def finalize_streaming(
        self,
        params: ConversationParams,
        answer_context: Union[str, Dict[str, Any], List[str]],
    ) -> RetrievalSummaryJson:
        """Build the final payload (answer, sources, filters) sent at the end of streaming."""
        user_profile = params.get("user_profile", None)

        # Send sources and filters at the end of the streaming
        return self.get_as_json(
            answer_context, user_profile=user_profile, uploaded_docs=params.get("media_summaries", [])
        )

    def finalize_non_streaming(
        self,
        params: ConversationParams,
        answer_context: Union[str, Dict[str, Any], List[str]],
    ) -> RetrievalSummaryJson:
        """Same payload as ``finalize_streaming``."""
        return self.finalize_streaming(params=params, answer_context=answer_context)

    def get_retrieval_context(
        self, params: ConversationParams
    ) -> Tuple[Optional[Union[MultipartContext, str]], Dict[str, Any]]:
        """Query the knowledge bank with ``params["kb_query"]`` and build the LLM context.

        Returns:
            A tuple ``(retrieved_context, answer_context)``. ``retrieved_context`` is a
            ``MultipartContext`` for multimodal knowledge banks with a multimodal LLM, a
            JSON string of structured documents otherwise, or an error message when the
            query fails. ``answer_context`` holds ``generated_question`` and, on success,
            ``context`` and ``source_documents``.
        """
        knowledge_bank_id = dataiku_api.webapp_config.get("knowledge_bank_id")
        kb_info = get_knowledge_bank_info(project=dataiku_api.default_project, knowledge_bank_id=knowledge_bank_id)
        kb_metadata_schema = kb_info["metadata_schema"]
        is_multimodal = kb_info["embedding_recipe_type"] == EmbeddingRecipeType.EMBED_DOCUMENTS

        user_query = params.get("user_query", "")
        kb_query = params.get("kb_query", "")
        if kb_query is None:
            kb_query = ""
        answer_context: Dict[str, Any] = {"generated_question": user_query}
        retrieved_context: Optional[Union[MultipartContext, str]] = None
        retriever = self.__get_retriever()
        try:
            docs = retriever.invoke(kb_query)
            logger.debug(f"KB query '{kb_query}' retrieved docs {len(docs)} docs")
            if is_multimodal and get_llm_capabilities().get("multi_modal", False):
                logger.debug(f"Multimodal retrieval context is : {docs}")
                retrieved_context = from_documents_to_multipart_context(docs, str(knowledge_bank_id))
                answer_context["context"] = retrieved_context
            else:
                structured_documents = from_documents_to_structured_context(docs, kb_metadata_schema)
                answer_context["context"] = structured_documents
                retrieved_context = json.dumps(structured_documents, ensure_ascii=False, indent=4)

            answer_context["source_documents"] = docs
            return retrieved_context, answer_context
        except Exception as e:
            # Best effort: the failure is surfaced to the LLM as context, but log it so it is not lost.
            logger.exception(f"Knowledge bank query '{kb_query}' failed")
            retrieved_context = f"Unable to run knowledge bank query: {e}"
            return retrieved_context, answer_context

    def get_as_json(
        self,
        answer_context: Union[str, Dict[str, Any], List[str]],
        user_profile: Optional[Dict[str, Any]] = None,
        uploaded_docs: Optional[List[MediaSummary]] = None,
    ) -> RetrievalSummaryJson:
        """Shape the answer, its sources and the applied filters into the response payload.

        Side effect: ``self.__input_filters`` is unwrapped from its
        ``{knowledge_bank_id: ...}`` form to the value for the selected knowledge bank
        (None when empty or when no knowledge bank is selected). The method is therefore
        not idempotent; calling it twice on the same chain loses the filters.

        Returns:
            The payload dict, or ``{}`` when ``answer_context`` has an unsupported type.
        """
        llm_context: LLMContext = {}  # type: ignore
        llm_context["selected_retrieval_info"] = get_retriever_info(config=self.webapp_config)

        def _unwrap_filters(filters: Any) -> Any:
            # Non-dict values are already unwrapped (or None): keep them as is.
            if not isinstance(filters, dict):
                return filters
            if not filters or not self.knowledge_banks_to_use:
                return None
            selected = filters.get(self.knowledge_banks_to_use[0], None)
            return None if selected == {} else selected

        self.__input_filters = _unwrap_filters(self.__input_filters)

        if self.knowledge_banks_to_use:
            llm_context["llm_kb_selection"] = self.knowledge_banks_to_use
        if user_profile:
            llm_context["user_profile"] = user_profile
        # TODO: do we need this case ?
        if isinstance(answer_context, str):
            return {
                "answer": answer_context,
                "sources": [],
                "filters": self.input_filters,
                "knowledge_bank_selection": self.knowledge_banks_to_use,
            }
        llm_context["uploaded_docs"] = [
            {
                "original_file_name": str(uploaded_doc.get("original_file_name")),
                "file_path": str(uploaded_doc.get("file_path")),
                "metadata_path": str(uploaded_doc.get("metadata_path")),
            }
            for uploaded_doc in (uploaded_docs or [])
        ]
        if isinstance(answer_context, dict):
            answer = answer_context.get("answer", "")
            source_documents: List[Document] = answer_context.get("source_documents", [])
            formatted_sources = generate_sources_from_source_documents(source_documents)

            return {
                "answer": answer,
                "sources": formatted_sources,
                "filters": self.input_filters,
                "llm_context": llm_context,
            }
        logger.error(f"Generated answer type not supported. This should not happen. {answer_context}")
        return {}

    def create_query_from_history_and_update_params(
        self, chat_history: List[LlmHistory], user_query: str, params: ConversationParams
    ):
        """Compute the retrieval query and record the knowledge bank selection in ``params``.

        Sets ``params["justification"]``, ``params["kb_query"]`` and
        ``params["knowledge_bank_selection"]``, and returns the updated ``params``.
        A None retrieval query means no knowledge bank is used for this turn.
        """
        retrieval_query, justification = self.retrieval_query_agent.get_retrieval_query(
            decision_llm=self.decision_llm, chat_history=chat_history, user_input=user_query, conversation_params=params
        )
        params["justification"] = justification
        params["kb_query"] = retrieval_query
        logger.debug(f"Computed Retrieval query from the user query: [{user_query}], is [{retrieval_query}]")
        # Copy so later mutation of the selection cannot alter the configured id list.
        self.__knowledge_banks_to_use = (
            [] if retrieval_query is None else list(self.__kbs_params.knowledge_bank_ids)
        )
        params["knowledge_bank_selection"] = self.knowledge_banks_to_use
        if not self.__knowledge_banks_to_use:
            logger.warning("No knowledge bank has been selected but knowledge bank is enabled.")
        # .get: the f-string is evaluated even when debug logging is off, so a missing key must not raise.
        logger.debug(f"""Selected knowledge bank: {self.knowledge_banks_to_use}
        retrieval_enabled: {params.get("retrieval_enabled")}""")
        return params
