import json
import logging
import os
from typing import Any, Optional, Dict, List, Union
from time import sleep

import pandas as pd
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from langchain_milvus import BM25BuiltInFunction
from langchain_milvus import Milvus
from pymilvus import MilvusClient
from pymilvus import MilvusException
from pymilvus.client.types import DataType
from pymilvus.exceptions import ConnectionConfigException

from dataiku.core.vector_stores.dku_vector_store import DkuLocalVectorStore, logger
from dataiku.core.vector_stores.vector_store_document_filter import VectorStoreDocumentFilter, ColumnType
from dataiku.langchain import DKUEmbeddings
from dataiku.langchain.metadata_generator import DKU_DOCUMENT_INFO
from dataiku.langchain.metadata_generator import DKU_MULTIMODAL_CONTENT
from dataiku.langchain.metadata_generator import DKU_SECURITY_TOKENS_META
from dataiku.llm.types import RetrievableKnowledge, BaseVectorStoreQuerySettings

# Only let ERROR (and above) records through for the langchain Milvus vectorstore logger
logging.getLogger("langchain_milvus.vectorstores.milvus").setLevel(logging.ERROR)

# File name of the local Milvus database, joined to the execution folder to build the client uri
MILVUS_DB_NAME = "milvus.db"
# Name of the single collection used by the local Milvus vector store
COLLECTION_NAME = "dku_milvus"

class MilvusLocalVectorStore(DkuLocalVectorStore):

    def __init__(self, kb: RetrievableKnowledge, exec_folder: str):
        """Set up a local Milvus vector store for a knowledge bank.

        :param kb: knowledge bank settings, forwarded to the parent class
        :param exec_folder: execution folder, forwarded to the parent class; the
            database file is located at `<self.exec_folder>/milvus.db`
        """
        super().__init__(kb, exec_folder, collection_name=COLLECTION_NAME)
        # The client is created lazily by the `client` property. The slot must exist
        # before `_complete_metadata_columns` runs, since that may query the database.
        self._client = None
        self.uri = os.path.join(self.exec_folder, MILVUS_DB_NAME)
        self.document_filter = MilvusVectorStoreDocumentFilter(self.metadata_column_type_and_meaning)
        self._complete_metadata_columns()

    @property
    def client(self):
        if self._client is None:
            self._client = self._get_client_with_retry()
        return self._client

    def _get_client_with_retry(self, retries: int = 1, delay: float = 5):
        """Open a `MilvusClient` on `self.uri`, retrying when initialisation fails.

        With the default arguments this makes one attempt, waits 5 seconds on
        failure, then makes a second and last attempt.

        :param retries: number of additional attempts made after a failed one
        :param delay: number of seconds to wait before each additional attempt
        :return: the `MilvusClient` instance
        :raises MilvusException, ConnectionConfigException: if the last attempt fails as well
        """
        for _ in range(retries):
            try:
                return MilvusClient(uri=self.uri)
            except (MilvusException, ConnectionConfigException):
                logger.warning("Milvus client failed to initialise, retrying in %ss", delay)
                sleep(delay)
        # Last attempt: any failure is propagated to the caller
        return MilvusClient(uri=self.uri)

    def _complete_metadata_columns(self):
        """Add missing metadata columns

        Milvus Lite complains if metadata is not provided or is null.
        """
        # Add dku-generated metadata columns for multimodal Kbs
        if self.kb.get("multimodalColumn"):
            self.metadata_column_type_and_meaning.update({DKU_MULTIMODAL_CONTENT: ("string", None), DKU_DOCUMENT_INFO: ("string", None)})

        reserved_fields = ['pk', 'text', 'vector', 'sparse']
        # If a metadata column was removed from `metadata_column_type`, we fetch it from the client
        if self._db_exists():
            collection_desc = self.client.describe_collection(COLLECTION_NAME)
            meta_fields = [
                f for f in collection_desc["fields"]
                if not f.get("is_primary") and f["name"] not in reserved_fields
            ]
            for field in meta_fields:
                if field["name"] not in self.metadata_column_type_and_meaning:
                    converted_type = "string"
                    if field["type"] == DataType.BOOL:
                        converted_type = "boolean"
                    elif field["type"] in [DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64]:
                        converted_type = "int"
                    elif field["type"] in [DataType.FLOAT, DataType.DOUBLE]:
                        converted_type = "float"
                    elif field["type"] == DataType.ARRAY:
                        converted_type = "dku_array"
                    self.metadata_column_type_and_meaning[field["name"]] = (converted_type, None)

    @staticmethod
    def _get_search_kwargs(query_settings: BaseVectorStoreQuerySettings) -> Dict:
        """Return the parent search kwargs with `fetch_k` aligned on `k`."""
        search_kwargs = DkuLocalVectorStore._get_search_kwargs(query_settings)
        # Hybrid search must retrieve enough documents, so `fetch_k` follows `k`.
        # See https://github.com/langchain-ai/langchain-milvus/blob/libs/milvus/v0.2.1/libs/milvus/langchain_milvus/vectorstores/milvus.py#L1518
        search_kwargs.update(fetch_k=search_kwargs["k"])
        return search_kwargs

    @staticmethod
    def _get_db_kwargs(query_settings: BaseVectorStoreQuerySettings) -> Dict:
        return {"hybrid": query_settings.get("searchType") == "HYBRID"}

    @staticmethod
    def _get_search_type(query_settings: BaseVectorStoreQuerySettings) -> str:
        """Return the langchain search type matching the query settings."""
        requested = query_settings.get("searchType")
        # `as_retriever` must be called with the similarity search type: hybrid search
        # is driven by the `hybrid` kwarg given to `get_db`, not by the search type.
        effective = "SIMILARITY" if requested == "HYBRID" else requested
        return DkuLocalVectorStore.dss_search_type_to_langchain_search_type(effective)

    @staticmethod
    def _get_retriever_kwargs(query_settings: BaseVectorStoreQuerySettings) -> Dict:
        advanced_reranking = query_settings.get("useAdvancedReranking", False)
        if advanced_reranking:
            # TODO: allow customization of `k`
            # Use recommended value: https://milvus.io/docs/rrf-ranker.md#RRF-Ranker:~:text=While,a%20common%20choice
            k = 60
            return {
                "ranker_type": "rrf",
                "ranker_params": {"k": k}
            }

        return {}

    def _fix_metadata_in_results(self, results: List[tuple[Document, float]]) -> List[tuple[Document, float]]:
        for doc, _ in results:
            for key, val in doc.metadata.items():
                # Milvus returns google.protobuf.pyext._message.RepeatedScalarContainer for array columns
                # We need to convert them to list to be json serializable
                if type(val).__name__ == "RepeatedScalarContainer":
                    doc.metadata[key] = list(val)
        return results

    def search_with_scores(self, query: str, embeddings: DKUEmbeddings, query_settings: BaseVectorStoreQuerySettings,
                           additional_search_kwargs: Optional[Dict] = None) -> List[tuple[Document, float]]:
        """Search the vector store and return (document, score) pairs.

        :param query: text to search for
        :param embeddings: embedding function used to build the vector store
        :param query_settings: settings from which search type, search kwargs and ranker are derived
        :param additional_search_kwargs: optional dict of extra search kwargs, taking precedence over the derived ones
        :return: the results, with array metadata converted to lists
        """
        search_kwargs = self._get_search_kwargs(query_settings)
        if isinstance(additional_search_kwargs, dict):
            merged_kwargs = dict(search_kwargs)
            merged_kwargs.update(additional_search_kwargs)
            search_kwargs = merged_kwargs

        db_kwargs = self._get_db_kwargs(query_settings)
        retriever_kwargs = self._get_retriever_kwargs(query_settings)
        vectorstore_db = self.get_db(embeddings, **db_kwargs)

        # In hybrid mode we avoid `similarity_search_with_relevance_scores` because no relevance function exists.
        # The WeightedRanker of Milvus normalizes and combines the scores between the two types of search.
        if query_settings.get("searchType") == "HYBRID":
            run_search = vectorstore_db.similarity_search_with_score
        else:
            run_search = vectorstore_db.similarity_search_with_relevance_scores

        return self._fix_metadata_in_results(run_search(query, **search_kwargs, **retriever_kwargs))

    def get_db(self, embeddings: Embeddings, allow_creation: bool = False, **kwargs: Any) -> VectorStore:
        """Build the langchain `Milvus` vector store bound to the local database file.

        :param embeddings: embedding function handed to the `Milvus` store
        :param allow_creation: accept that the collection doesn't exist yet; the hybrid
            schema is then always used
        :param kwargs: only `hybrid` is read, and only when the collection already exists
        :return: the `Milvus` vector store
        :raises Exception: if the collection doesn't exist and `allow_creation` is False
        """
        db_exists = self._db_exists()
        if not db_exists and not allow_creation:
            raise Exception("Milvus vector store doesn't exist yet, you may need to run the embedding recipe.")

        # At creation time, the hybrid index is required so that we create the correct schema
        hybrid = kwargs.get("hybrid", False) if db_exists else True

        dense_index = {"index_type": "FLAT", "metric_type": "COSINE"}
        if hybrid:
            builtin_function = BM25BuiltInFunction()
            vector_field = ["vector", "sparse"]
            index_params = [dense_index, {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "BM25"}]
        else:
            builtin_function = None
            vector_field = ["vector"]
            index_params = [dense_index]

        def varchar_array_field():
            # A fresh dict per field so that no schema entry is shared
            return {"dtype": DataType.ARRAY, "kwargs": {"element_type": DataType.VARCHAR}}

        # Security tokens and bag-of-words columns are stored as arrays of strings
        metadata_schema = {DKU_SECURITY_TOKENS_META: varchar_array_field()}
        metadata_schema.update(
            (name, varchar_array_field())
            for name, (_, meaning) in self.metadata_column_type_and_meaning.items()
            if meaning == "BagOfWordsMeaning"
        )

        return Milvus(
            embedding_function=embeddings,
            builtin_function=builtin_function,
            connection_args={"uri": self.uri},
            collection_name=self.collection_name,
            vector_field=vector_field,
            index_params=index_params,
            metadata_schema=metadata_schema
        )

    def _db_exists(self):
        return os.path.isfile(self.uri) and self.client.has_collection(COLLECTION_NAME)

    def _prepare_and_clean_metadata(self, document):
        new_meta = {}
        for key, val in document.metadata.items():
            column_type, meaning = self.metadata_column_type_and_meaning.get(key, (None, None))
            if column_type is None:
                pass
            elif column_type in ["float", "double", "tinyint", "smallint", "int", "bigint"]:
                # Keep nans because they are allowed in float columns
                val = float(val)
            elif column_type in ["boolean"]:
                if pd.isna(val):
                    val = False
            elif column_type == "dku_array" or meaning == "BagOfWordsMeaning":
                if not isinstance(val, list) and pd.isna(val):
                    val = []
            elif column_type in ["date", "datetimetz", "datetimenotz", "dateonly"] and isinstance(val, pd.Timestamp):
                # Converting to string
                val = self.process_timestamp(val)
            else:
                if pd.isna(val):
                    val = ""
                val = str(val)
            new_meta[key] = val

        # Add any missing metadata columns
        for key, (column_type, meaning) in self.metadata_column_type_and_meaning.items():
            if key not in document.metadata:
                if column_type in ["float", "double", "tinyint", "smallint", "int", "bigint"]:
                    new_meta[key] = float("nan")
                elif column_type in ["boolean"]:
                    new_meta[key] = False
                elif column_type == "dku_array" or meaning == "BagOfWordsMeaning":
                    new_meta[key] = []
                else:
                    new_meta[key] = ""

        document.metadata = new_meta
        return document

    def transform_document_before_load(self, document: "Document") -> "Document":
        """
        Hook for a vectorstore to rewrite the document at load time.
        """
        if self.document_filter is not None:
            document = self.document_filter.add_security_tokens_to_document(document)
            self.metadata_column_type_and_meaning[DKU_SECURITY_TOKENS_META] = ("dku_array", None)
        document = self._prepare_and_clean_metadata(document)
        return document

    def include_similarity_score(self, query_settings: BaseVectorStoreQuerySettings) -> bool:
        return True

    def clear_files(self, folder_path: str) -> None:
        """Drop the Milvus collection through the client.

        :param folder_path: not used by this implementation
        """
        self.client.drop_collection(COLLECTION_NAME)

    def get_file_size(self) -> int:
        size = 0

        if os.path.isfile(self.uri):
            size += os.path.getsize(self.uri)

        return size

    def get_document_metadata(self, document: Document, column: str) -> Optional[Any]:
        column_type, meaning = self.metadata_column_type_and_meaning.get(column, (None, None))
        if meaning == "BagOfWordsMeaning":
            return list(document.metadata.get(column, []))
        else:
            return document.metadata.get(column, None)


class MilvusVectorStoreDocumentFilter(VectorStoreDocumentFilter):
    """ Milvus uses Python-like operators
    See https://milvus.io/docs/basic-operators.md
    And further examples for each data type, e.g. https://milvus.io/docs/string.md
    """
    def _coerce_int(self, item):
        if type(item) == int:
            return float(item)
        return item

    def _coerce_value(self, clause: Dict) -> Union[List, str, int, float]:
        value = clause["value"]
        if type(value) == list:
            value = [self._coerce_int(item) for item in value]
        elif type(value) == int:
            return float(value)
        elif type(value) == str:
            return '"' + value + '"'
        return value

    def convert(self, clause: Dict) -> str:
        if clause["operator"] == "EQUALS":
            column_type = self._get_column_type(clause)
            if column_type == ColumnType.BagOfWords:
                return f'ARRAY_CONTAINS({clause["column"]}, {self._coerce_value(clause)})'
            return f'({clause["column"]} == {self._coerce_value(clause)})'
        elif clause["operator"] == "NOT_EQUALS":
            column_type = self._get_column_type(clause)
            if column_type == ColumnType.BagOfWords:
                raise Exception("Milvus local vector store doesn't support `not` clause on 'bag-of-words' columns.")
            return f'({clause["column"]} != {self._coerce_value(clause)})'
        elif clause["operator"] == "GREATER_THAN":
            return f'({clause["column"]} > {self._coerce_value(clause)})'
        elif clause["operator"] == "LESS_THAN":
            return f'({clause["column"]} < {self._coerce_value(clause)})'
        elif clause["operator"] == "GREATER_OR_EQUAL":
            return f'({clause["column"]} >= {self._coerce_value(clause)})'
        elif clause["operator"] == "LESS_OR_EQUAL":
            return f'({clause["column"]} <= {self._coerce_value(clause)})'
        elif clause["operator"] == "IN_ANY_OF":
            column_type = self._get_column_type(clause)
            if column_type == ColumnType.BagOfWords:
                clauses = [{"column": clause["column"], "operator": "EQUALS", "value": v} for v in clause["value"]]
                return self.convert({"operator": "OR", "clauses": clauses})
            return f'({clause["column"]} in {self._coerce_value(clause)})'
        elif clause["operator"] == "IN_NONE_OF":
            column_type = self._get_column_type(clause)
            if column_type == ColumnType.BagOfWords:
                raise Exception("Milvus local vector store doesn't support `not` clause on 'bag-of-words' columns.")
            return f'({clause["column"]} not in {self._coerce_value(clause)})'
        elif clause["operator"] == "CONTAINS":
            return f'({clause["column"]} LIKE %{self._coerce_value(clause)}%)'
        elif clause["operator"] == "AND":
            return '(' + ' and '.join([self.convert(x) for x in clause["clauses"]]) + ')'
        elif clause["operator"] == "OR":
            return '(' + ' or '.join([self.convert(x) for x in clause["clauses"]]) + ')'
        # custom operator to search a string in a list of strings, used for security tokens, not exposed to users
        elif clause["operator"] == "ARRAY_CONTAINS":
            return f'ARRAY_CONTAINS({clause["column"]}, {self._coerce_value(clause)})'
        else:
            raise ValueError("Unsupported filter operator for Milvus local vector store: %s" % clause["operator"])

    def _and_filter(self, converted_clause_a: str, converted_clause_b: str) -> str:
        return f'({converted_clause_a} and {converted_clause_b})'

    def _get_filter_field_name(self) -> str:
        return "expr"

    def add_security_tokens_to_document(self, document: Document) -> Document:
        """ Replace the JSON-encoded security tokens metadata by its decoded value,
        the field that add_security_filter() relies on.

        A document without security tokens metadata is returned untouched.
        """
        if DKU_SECURITY_TOKENS_META not in document.metadata:
            return document

        encoded_tokens = document.metadata[DKU_SECURITY_TOKENS_META]
        document.metadata[DKU_SECURITY_TOKENS_META] = json.loads(encoded_tokens)
        logging.info("Updated metadata to %s" % self._sanitize_metadata_for_print(document.metadata))
        return document

    def _get_security_token_check_clause(self, security_token: str) -> Dict:
        """Build the clause checking that the security tokens column contains `security_token`."""
        return dict(
            operator="ARRAY_CONTAINS",
            column=DKU_SECURITY_TOKENS_META,
            value=security_token,
        )
