import traceback
from enum import Enum
from typing import Any, Dict, List, Tuple, Union

import numpy as np
from answers.backend.models.base import VectorstoreDocument
from common.backend.utils.metas_utils import is_valid_enum_value
from common.backend.utils.sql_timing import log_execution_time
from common.llm_assist.logging import logger
from dataiku.core.knowledge_bank import KnowledgeBank


class DSSRetrieverSearchKwarg(str, Enum):
    """Names of the search parameters accepted by `DSSVectorstoreRetriever`.

    Members are also `str` instances, so each one compares equal to its
    plain-string value (e.g. `DSSRetrieverSearchKwarg.K == "k"`).
    """

    # Expected to be an integer by `DSSVectorstoreRetriever.validate_search_kwargs`.
    K = "k"
    # Expected to be a number (int or float).
    SCORE_THRESHOLD = "score_threshold"
    # Expected to be a dict.
    FILTER = "filter"


class DSSVectorstoreRetriever:
    """
    Retriever built on top of the langchain vectorstore of a DSS knowledge bank.

    This class is designed to get around the langchain constraints regarding the similarity score threshold
        at the moment of the `invoke` call (langchain retrievers force the `score_threshold` to be bounded between [0,1[ ).

    Similarly to langchain vectorstores, it has logic to pass retrieval search parameters and an `invoke` method.

    Practical use in 'def get_knowledge_bank_retriever' that is called in the KBRetrievalChain.__get_retriever.
    """
    # Number of documents requested from the vectorstore when no valid 'k' is set.
    DEFAULT_K = 5
    # Maximum number of characters of a document's content written to the debug logs.
    MAX_CONTENT_CHARS_IN_LOGS = 100

    @log_execution_time
    def __init__(self, knowledge_bank: KnowledgeBank):
        """Wrap `knowledge_bank` as a langchain vectorstore and set the default search parameters."""
        self.vectorstore = knowledge_bank.as_langchain_vectorstore()
        self._search_kwargs: Dict[Union[str, DSSRetrieverSearchKwarg], Any] = {
            "k": self.DEFAULT_K
        }
        logger.debug(f"{self.__class__.__name__} initialized", log_conv_id=True)

    @property
    def search_kwargs(self):
        """Search parameters currently used by `invoke` (the live dict, not a copy)."""
        return self._search_kwargs

    def validate_search_kwargs(self, search_kwargs: Dict[Union[str, DSSRetrieverSearchKwarg], Any]):
        """
        Check that every entry of `search_kwargs` is a supported parameter with a value of the expected type.

        Raises:
            TypeError: if 'k' is not an integer, 'score_threshold' is not a number or 'filter' is not a dict.
            Exception: if a key is not a valid `DSSRetrieverSearchKwarg` value.
        """
        for key, value in search_kwargs.items():
            if (key == DSSRetrieverSearchKwarg.K.value) and (not isinstance(value, int)):
                raise TypeError("'k' must be an integer")
            elif (key == DSSRetrieverSearchKwarg.SCORE_THRESHOLD.value) and (not isinstance(value, (int, float))):
                raise TypeError("'score_threshold' must be a number")
            elif (key == DSSRetrieverSearchKwarg.FILTER.value) and (not isinstance(value, dict)):
                raise TypeError("'filter' must be a dict")
            elif not is_valid_enum_value(key, DSSRetrieverSearchKwarg):
                log_message = f"The parameter '{key}' can't be used as a search parameter in the {self.__class__.__name__}"
                # NOTE(review): `logger.exception` is called outside of an `except` block here, so there is
                # no active traceback to attach - confirm whether an error-level call would be more appropriate.
                logger.exception(log_message, log_conv_id=True)
                raise Exception(log_message)

    def set_search_kwargs(self, search_kwargs: Dict[Union[str, DSSRetrieverSearchKwarg], Any]):
        """
        Validate `search_kwargs` and merge them into the current search parameters.

        Raises:
            ValueError: if `search_kwargs` is not a dictionary.
            TypeError / Exception: propagated from `validate_search_kwargs`.
        """
        if not isinstance(search_kwargs, dict):
            raise ValueError(f"The arguments passed, `{search_kwargs}` are of type `{type(search_kwargs)}`: a dictionary is expected.")
        # Validate first so that invalid parameters never reach the stored search parameters.
        self.validate_search_kwargs(search_kwargs)
        self._search_kwargs.update(search_kwargs)

    @log_execution_time
    def invoke(self, query: str,) -> List[VectorstoreDocument]:
        """
        Retrieve the documents matching `query` from the vectorstore.

        The search uses the current search parameters: 'k' documents are requested (optionally restricted
        by 'filter'), then documents whose score is below 'score_threshold' are discarded when a
        threshold is set. Retrieved and excluded documents are written to the debug logs.

        Returns:
            The retained documents, in the order returned by the vectorstore (possibly an empty list).

        Raises:
            RuntimeError: if anything fails during the retrieval; the original exception is chained as the cause.
        """
        search_docs: List[VectorstoreDocument] = []
        all_docs_info = []
        all_docs_scores: List[float] = []
        excluded_docs_info = []
        # A falsy 'k' (0 or None) cannot be used for the search: fall back to the default.
        if not self.search_kwargs["k"]:
            self.set_search_kwargs({"k": self.DEFAULT_K})
        k = self.search_kwargs["k"]
        score_threshold = self.search_kwargs.get("score_threshold", None)
        metadata_filter = self.search_kwargs.get("filter", None)
        # The filter is only forwarded when it is set and non-empty.
        opt_kwargs = {"filter": metadata_filter} if metadata_filter else {}
        logger.debug(f"Retrieving documents with `{self.__class__.__name__} (search_kwargs={self.search_kwargs})`", log_conv_id=True)
        try:
            initial_search_results: List[Tuple[VectorstoreDocument, float]] = self.vectorstore.similarity_search_with_relevance_scores(query=query, k=k, **opt_kwargs)
            for index, (document, score) in enumerate(initial_search_results):
                doc_info: Dict[str, Any] = {"index": index, "score": score}
                all_docs_scores.append(score)
                # `is not None` (rather than truthiness) so that a threshold of 0 is still applied.
                if score_threshold is not None and score < score_threshold:
                    # Only the index and the score are kept for excluded documents.
                    excluded_docs_info.append(doc_info.copy())
                else:
                    search_docs.append(document)

                doc_info["metadata"] = document.metadata
                content_to_log = document.page_content
                if len(content_to_log) > self.MAX_CONTENT_CHARS_IN_LOGS:
                    content_to_log = f"{content_to_log[0: self.MAX_CONTENT_CHARS_IN_LOGS]} (only first {self.MAX_CONTENT_CHARS_IN_LOGS} characters displayed)]"
                doc_info["page_content"] = content_to_log
                all_docs_info.append(doc_info)

            if all_docs_scores:
                scores = np.array(all_docs_scores)
                score_metrics = f"Document scores stats: max={round(np.max(scores), 3)} | mean={round(np.mean(scores), 3)} | median={round(np.median(scores), 3)} | min={round(np.min(scores), 3)}"
            else:
                # numpy reductions such as max/min raise on an empty array: skip the stats
                # so that a search without any result returns an empty list instead of failing.
                score_metrics = "Document scores stats: not available (no document retrieved)"
            logger.debug(f"{len(all_docs_info)} documents retrieved from the vectorstore with the query '{query}': `{all_docs_info}`\n{score_metrics}", log_conv_id=True)
            if excluded_docs_info:
                log_message = f"{len(excluded_docs_info)} documents excluded due to a score under the threshold of '{score_threshold}' : "\
                    f"`{excluded_docs_info}`"
                logger.debug(log_message, log_conv_id=True)

            return search_docs
        except Exception as e:
            # Chain the original exception so that its type and traceback stay available to callers.
            raise RuntimeError(f"RuntimeError in `{self.__class__.__name__}.invoke`: {traceback.format_exc()}") from e