import json
import logging
from typing import Union, Dict, Tuple, Any, Optional
from collections.abc import Callable

from elasticsearch import UnsupportedProductError
from elasticsearch.helpers.vectorstore import DenseVectorStrategy, RetrievalStrategy
from langchain_core.documents import Document
from langchain_elasticsearch import ElasticsearchStore

from dataiku.core.vector_stores.dku_vector_store import DkuRemoteVectorStore, logger
from dataiku.core.vector_stores.vector_store_document_filter import DKU_SECURITY_TOKEN_META_PREFIX, VectorStoreDocumentFilter, ColumnType
from dataiku.langchain.metadata_generator import DKU_SECURITY_TOKENS_META
from dataiku.llm.types import RetrievableKnowledge, BaseVectorStoreQuerySettings, SearchType
from dataikuapi.dss.admin import DSSConnectionInfo
from dataikuapi.dss.langchain import DKUEmbeddings


class ElasticSearchVectorStore(DkuRemoteVectorStore):
    """Remote vector store implementation backed by an ElasticSearch index.

    Builds the connection URL and authentication kwargs from a DSS connection,
    detects the server version and picks a retrieval strategy compatible with
    that version when instantiating the langchain ``ElasticsearchStore``.
    """

    def __init__(self, kb: RetrievableKnowledge, exec_folder: str, connection_info_retriever: Callable[[str], DSSConnectionInfo]):
        """
        :param kb: knowledge bank settings, forwarded to the parent constructor
        :param exec_folder: execution folder, forwarded to the parent constructor
        :param connection_info_retriever: callable returning the connection info for a connection name
        """
        # Defaults are assigned before delegating to the parent constructor.
        # NOTE(review): presumably the parent constructor calls `init_connection`, which overwrites them - confirm.
        self.elasticsearch_url: str = ""
        self.auth_kwargs: Dict = {}
        self.version: Tuple[int, int] = (0, 0)
        super().__init__(kb, exec_folder, connection_info_retriever)
        self.document_filter = ElasticSearchVectorStoreDocumentFilter(self.metadata_column_type)

    def set_index_name(self, index_name: str) -> None:
        """Store the index name, lower-cased."""
        self.index_name = index_name.lower()

    def init_connection(self) -> None:
        """
        Read the connection parameters and populate `elasticsearch_url`, `auth_kwargs` and `version`

        :raises NotImplementedError: for the "OAUTH2_APP" authentication type
        :raises UnsupportedProductError: for AWS authentication types (OpenSearch only)
        :raises ValueError: for any other unknown authentication type
        """
        connection_params = self.connection_info_retriever(self.connection_name).get_params()
        auth_type = connection_params['authType']
        self.auth_kwargs = {}
        scheme = "http"
        if connection_params['ssl']:
            scheme = "https"
            # Certificate verification can only be disabled when SSL is enabled
            if connection_params['trustAnySSLCertificate']:
                self.auth_kwargs['es_params'] = {'verify_certs': False}

        self.elasticsearch_url = "{}://{}:{}".format(scheme, connection_params['host'], connection_params['port'])
        if auth_type == "NONE":
            pass
        elif auth_type == "PASSWORD":
            self.auth_kwargs['es_user'] = connection_params['username']
            self.auth_kwargs['es_password'] = connection_params['password']
        elif auth_type == "OAUTH2_APP":
            raise NotImplementedError("OAuth v2.0 is not supported for Knowledge Banks")
        elif auth_type in ["AWS_KEYPAIR", "AWS_ENVIRONMENT", "AWS_STS", "AWS_CUSTOM"]:
            # This error will be caught by the VectorStoreFactory and trigger the use of OpenSearch implementation instead
            raise UnsupportedProductError("'{}' authentication is only supported with OpenSearch".format(auth_type), {}, {"version": {"distribution": "opensearch"}})
        else:
            # `format` (rather than `+`) so that a None / non-string value still yields the intended ValueError
            raise ValueError("Unknown authentication type: {}".format(auth_type))

        self.version = self.get_version()

    def get_rrf_settings(self, query_settings: BaseVectorStoreQuerySettings) -> Union[Dict, bool]:
        """
        Build the RRF (reciprocal rank fusion) settings from the query settings

        :param query_settings: query settings, read for `useAdvancedReranking`, `rrfRankConstant` and `rrfRankWindowSize`
        :return: a dict with `rank_constant` (default 60) and `rank_window_size` (default: the `k` search kwarg)
                 when advanced reranking is enabled, `False` otherwise
        """
        advanced_reranking = query_settings.get("useAdvancedReranking", False)
        if advanced_reranking:
            k = self._get_search_kwargs(query_settings)["k"]
            return {
                "rank_constant": query_settings.get("rrfRankConstant", 60),
                "rank_window_size": query_settings.get("rrfRankWindowSize", k)
            }
        return False

    def _get_db_kwargs(self, query_settings: BaseVectorStoreQuerySettings) -> Dict:
        """
        Build the extra kwargs (the retrieval `strategy`) used to instantiate the db for a query

        :raises ValueError: if RRF is requested on a server older than v8.16
        """
        hybrid = query_settings.get("searchType") == "HYBRID"
        rrf_settings = self.get_rrf_settings(query_settings)
        if self.version < (8, 16) and rrf_settings:
            raise ValueError("RRF is only supported with ElasticSearch v8.16+")
        strategy = self.get_strategy(hybrid, rrf_settings)
        return {"strategy": strategy}

    @staticmethod
    def _get_search_type(query_settings: BaseVectorStoreQuerySettings) -> str:
        """Return the langchain search type; "HYBRID" is mapped like "SIMILARITY" (hybrid is carried by the strategy)."""
        search_type = query_settings.get("searchType")
        if search_type == "HYBRID":
            search_type = "SIMILARITY"
        return DkuRemoteVectorStore.dss_search_type_to_langchain_search_type(search_type)

    def include_similarity_score(self, query_settings: BaseVectorStoreQuerySettings) -> bool:
        """Same as the parent implementation, except that it is always False for "HYBRID" searches."""
        search_type = query_settings.get("searchType")
        return super().include_similarity_score(query_settings) and search_type != "HYBRID"

    def get_db(self, embeddings: Optional[DKUEmbeddings], allow_creation: bool = False, **kwargs: Any) -> ElasticsearchStore:
        """
        Instantiate the ElasticsearchStore db object

        :param embeddings: embeddings used to create the vector store
        :param allow_creation:  unused
        :param kwargs: if the `strategy` is provided, it will be used to instantiate the db object instead of the default one
        :rtype: ElasticsearchStore
        """
        # todo should check if index already exist to raise an error if allow_creation=false first (langchain always create it by default if unfound)

        # Enforce correct strategy instantiation if none present in kwargs.
        # The default is only built when needed, so that an explicit strategy does not
        # trigger a useless (and misleading) strategy construction / log line.
        if "strategy" not in kwargs:
            kwargs["strategy"] = self.get_strategy()

        db = ElasticsearchStore(
            es_url=self.elasticsearch_url,
            index_name=self.index_name,
            embedding=embeddings,
            **{**self.auth_kwargs, **kwargs}
        )
        return db

    def get_strategy(self, hybrid: bool = False, rrf: Union[Dict, bool] = False) -> RetrievalStrategy:
        """
        Pick the retrieval strategy matching the detected server version

        :param hybrid: whether to enable hybrid search (ignored for servers older than v8.4)
        :param rrf: RRF settings, or False to disable RRF (ignored for servers older than v8.4)
        """
        # ES versions < v8.4 does not support the KNN query (default strategy used in `ElasticsearchStore`),
        # so we need to use a less efficient strategy based on custom script scoring
        if self.version < (8, 4):
            logger.info("Using DenseVectorScriptScoreStrategy instead of KNN query params.")
            from elasticsearch.helpers.vectorstore import DenseVectorScriptScoreStrategy
            return DenseVectorScriptScoreStrategy()
        else:
            return DenseVectorStrategy(hybrid=hybrid, rrf=rrf)

    def get_version(self) -> Tuple[int, int]:
        """
        Retrieve the (major, minor) version of the server

        :raises ConnectionError: if the version cannot be retrieved or parsed
        """
        # We retrieve the db first so that we get a chance to raise an UnsupportedProductError
        # that will be caught by the VectorStoreFactory if we are connecting to an OpenSearch
        db = ElasticsearchStore(index_name=self.index_name, es_url=self.elasticsearch_url, **self.auth_kwargs)
        try:
            version_number = db.client.info()["version"]["number"]
            logger.info("Connected to ElasticSearch v%s", version_number)
            v = version_number.split(".")
            return int(v[0]), int(v[1])
        # Reraise the exception as a ConnectionError, keeping the original one as explicit cause
        except Exception as e:
            raise ConnectionError("Could not retrieve ElasticSearch version: {}".format(e)) from e

    def clear_index(self) -> None:
        """Delete the index; a missing index is not an error (`ignore_unavailable` / `allow_no_indices`)."""
        self.get_db(None).client.indices.delete(
            index=self.index_name,
            ignore_unavailable=True,
            allow_no_indices=True
        )
        logger.info("Cleared ElasticSearch index %s", self.index_name)


class ElasticSearchVectorStoreDocumentFilter(VectorStoreDocumentFilter):
    """ ElasticSearch document filter

    Converts DSS filter clauses into ElasticSearch query DSL fragments targeting the `metadata.<column>` fields.

    See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-filter-context.html#filter-context
    and https://python.langchain.com/docs/integrations/vectorstores/elasticsearch/#query-vector-store
    """

    # DSS range operators -> keyword of the ElasticSearch `range` query
    _RANGE_OPERATORS = {
        "GREATER_THAN": "gt",
        "LESS_THAN": "lt",
        "GREATER_OR_EQUAL": "gte",
        "LESS_OR_EQUAL": "lte",
    }

    @staticmethod
    def _metadata_field(clause: Dict, keyword: bool = False) -> str:
        """ Name of the metadata field targeted by a clause, optionally its `.keyword` sub-field
        """
        field = f'metadata.{clause["column"]}'
        return field + ".keyword" if keyword else field

    def convert(self, clause: Dict) -> Dict:
        """ Convert a DSS filter clause into an ElasticSearch query

        :param clause: dict with an "operator" key, plus "clauses" for AND / OR or "column" (and a value) otherwise
        :return: the ElasticSearch query as a dict
        :raises ValueError: if the operator is not supported
        """
        operator = clause["operator"]

        # Composite clauses: they have no "column", only sub-clauses converted recursively
        if operator == "AND":
            return {"bool": {"filter": [self.convert(x) for x in clause["clauses"]]}}
        if operator == "OR":
            return {"bool": {"should": [self.convert(x) for x in clause["clauses"]]}}

        # Range comparisons are emitted without looking at the column type
        if operator in self._RANGE_OPERATORS:
            return {"range": {self._metadata_field(clause): {self._RANGE_OPERATORS[operator]: self._coerce_value(clause)}}}

        if operator not in ("EQUALS", "NOT_EQUALS", "IN_ANY_OF", "IN_NONE_OF"):
            raise ValueError("Unsupported filter operator for ElasticSearch: %s" % operator)

        # Remaining operators depend on the column type: strings are matched on
        # the `.keyword` sub-field, numbers with term-level queries on the field itself
        column_type = self._get_column_type(clause)
        is_string = column_type == ColumnType.String
        is_numeric = column_type in [ColumnType.Decimal, ColumnType.Integer]

        if operator == "EQUALS":
            if is_string:
                return {"match": {self._metadata_field(clause, keyword=True): self._coerce_value(clause)}}
            # Security token columns are also checked with an exact `term` query
            if is_numeric or clause["column"].startswith(DKU_SECURITY_TOKEN_META_PREFIX):
                return {"term": {self._metadata_field(clause): self._coerce_value(clause)}}
        elif operator == "NOT_EQUALS":
            if is_string:
                return {"bool": {"must_not": {"match": {self._metadata_field(clause, keyword=True): self._coerce_value(clause)}}}}
            if is_numeric:
                return {"bool": {"must_not": {"term": {self._metadata_field(clause): self._coerce_value(clause)}}}}
        elif operator == "IN_ANY_OF":
            if is_string:
                # NOTE(review): an empty value list produces an empty `should`, which ElasticSearch
                # presumably treats as match-all, unlike the numeric `terms` branch - confirm.
                keyword_field = self._metadata_field(clause, keyword=True)
                return {"bool": {"should": [{"match": {keyword_field: value}} for value in self._coerce_value(clause)]}}
            if is_numeric:
                return {"terms": {self._metadata_field(clause): self._coerce_value(clause)}}
        else:  # IN_NONE_OF
            if is_string:
                keyword_field = self._metadata_field(clause, keyword=True)
                return {"bool": {"must_not": [{"match": {keyword_field: value}} for value in self._coerce_value(clause)]}}
            if is_numeric:
                return {"bool": {"must_not": {"terms": {self._metadata_field(clause): self._coerce_value(clause)}}}}

        # No branch matched: this column type is not supported for this operator
        raise self._unsupported_type_for_operator(column_type, operator)

    def _and_filter(self, converted_clause_a: Dict, converted_clause_b: Dict) -> Dict:
        """ Combine two already converted clauses so that both must match
        """
        return {"bool": {"filter": [converted_clause_a, converted_clause_b]}}

    def _get_filter_field_name(self) -> str:
        """ Name of the field carrying the filter: "filter"
        """
        return "filter"

    def add_security_tokens_to_document(self, document: Document) -> Document:
        """ Add one metadata field per security token, set to the string "true"

        The tokens are read as a JSON list from the `DKU_SECURITY_TOKENS_META` metadata entry.
        The document is left untouched when that entry is absent.

        :param document: document to update in place
        :return: the same document
        """
        if DKU_SECURITY_TOKENS_META in document.metadata:
            security_tokens = json.loads(document.metadata[DKU_SECURITY_TOKENS_META])
            for token in security_tokens:
                document.metadata[self._get_security_token_with_prefix(token)] = "true"
            # Use the module logger (not the root logger), like the rest of this file
            logger.info("Updated metadata to %s", self._sanitize_metadata_for_print(document.metadata))
        return document

    def _get_security_token_check_clause(self, token: str) -> Dict:
        """ Build the EQUALS clause checking that a document carries the given security token
        """
        return {
            "operator": "EQUALS",
            "column": self._get_security_token_with_prefix(token),
            "value": "true",
        }

    @staticmethod
    def _get_security_token_with_prefix(token: str) -> str:
        """ Metadata key under which a security token is stored
        """
        return DKU_SECURITY_TOKEN_META_PREFIX + token
