import json
import logging
from abc import abstractmethod
from datetime import datetime, date
from enum import Enum
from typing import Dict, List, Union, Optional, Any, Callable

from langchain_core.documents import Document

from dataiku.base.utils import package_is_at_least_no_import
from dataiku.langchain.document_handler import DKU_MULTIMODAL_CONTENT
from dataiku.langchain.metadata_generator import DKU_SECURITY_TOKENS_META

# Module-level logger, named after this module
logger = logging.getLogger(__name__)

# Maximum number of characters of the multimodal content kept when metadata is sanitized for printing
METADATA_SAMPLE_LENGTH = 25
# Prefix prepended to a security token to build the name of its per-token metadata field
DKU_SECURITY_TOKEN_META_PREFIX = "DKU_SECURITY_TOKEN_"


class ColumnType(Enum):
    """ Simplified value types of a metadata column, used to build typed filters
    """
    Date = "Date"
    Decimal = "Decimal"
    Integer = "Integer"
    String = "String"
    BagOfWords = "BagOfWords"  # Keywords embedding

    @staticmethod
    def from_dss_type(dss_storage_type: str, meaning: Optional[str]):
        """ Map the storage type and meaning of a column to a simpler value type to help building typed filters

        :param dss_storage_type: DSS storage type of the column (e.g. "int", "string", "dateonly")
        :param meaning: meaning of the column, if any; "BagOfWordsMeaning" takes precedence over the storage type
        :return: the matching ColumnType, or None when the storage type is not supported
        """
        # The meaning is checked first: a bag-of-words column is typed by its meaning, whatever its storage type
        if meaning == "BagOfWordsMeaning":
            return ColumnType.BagOfWords

        # Kept local to the function: a class-level attribute of an Enum would become a member
        storage_types_by_column_type = (
            (ColumnType.Date, ("date", "dateandtime", "dateonly")),
            (ColumnType.Decimal, ("double", "float")),
            (ColumnType.Integer, ("bigint", "int", "smallint", "tinyint")),
            (ColumnType.String, ("string",)),
        )
        for column_type, storage_types in storage_types_by_column_type:
            if dss_storage_type in storage_types:
                return column_type

        # no support for operators on ["geopoint", "geometry", "map", "object"]
        return None


class VectorStoreDocumentFilter:
    """ A VectorStoreDocumentFilter implements DSS filters using vector store-specific methods, including document-level security tokens enforcement

    The vector store-specific pieces (``convert``, ``_and_filter``, ``add_security_tokens_to_document``,
    ``_get_security_token_check_clause`` and ``_get_filter_field_name``) raise NotImplementedError here
    and are expected to be provided by subclasses.
    """
    def __init__(self, metadata_column_type_and_meaning: Dict):
        """
        :param metadata_column_type_and_meaning: maps a metadata column name to a (DSS storage type, meaning) pair
        """
        self.metadata_column_type_and_meaning = metadata_column_type_and_meaning

    def _get_column_type(self, clause: Dict) -> Optional[ColumnType]:
        """ Return the simplified type of the column targeted by the clause, or None if the column is unknown or unsupported
        """
        dss_storage_type, meaning = self.metadata_column_type_and_meaning.get(clause.get("column"), (None, None))
        return ColumnType.from_dss_type(dss_storage_type, meaning)

    def _unsupported_type_for_operator(self, column_type: ColumnType, operator: str) -> Exception:
        """ Build (without raising) the exception describing an operator that cannot be applied to a column type
        """
        return Exception(f"Unsupported type {column_type} for filter operator {operator}")

    def _coerce_value(self, clause: Dict) -> Union[int, List[int], float, List[float], str, List[str], datetime, List[datetime], date, List[date]]:
        """ Return the clause value cast to the Python type matching the column type

        Integer columns are cast with ``int`` and Decimal columns with ``float``; a list value is cast
        element by element. Values of any other column type are returned untouched.

        :raises ValueError, TypeError: when a value cannot be cast to the numeric type of its column
        """
        column_type = self._get_column_type(clause)
        if column_type == ColumnType.Integer:
            cast = int
        elif column_type == ColumnType.Decimal:
            cast = float
        else:
            return clause["value"]

        value = clause["value"]
        if isinstance(value, list):
            return [cast(v) for v in value]
        return cast(value)

    @abstractmethod
    def convert(self, simple_filter: Dict) -> Dict:
        """ Convert a SimpleFilter to a vector store-specific filter object
        """
        raise NotImplementedError()

    @abstractmethod
    def _and_filter(self, converted_clause_a: Dict, converted_clause_b: Dict) -> Dict:
        """ Return the vector store-specific filter that ANDs the two vector store-specific filter clauses
        """
        raise NotImplementedError()

    @abstractmethod
    def add_security_tokens_to_document(self, document: Document) -> Document:
        """ Add fields used to implement security tokens based document filtering
        """
        raise NotImplementedError()

    @abstractmethod
    def _get_security_token_check_clause(self, token: str) -> Dict:
        """ Return a simple filter clause object that filters document to only return those that hold the given token
        """
        raise NotImplementedError()

    def add_security_filter(self, search_kwargs: Dict, caller_security_tokens: List) -> None:
        """ Add filters to a search query to enforce security tokens

        A document passes the security filter when it holds at least one of the caller's tokens.

        :param search_kwargs: search arguments, updated in place
        :param caller_security_tokens: security tokens of the caller
        :raises Exception: when no token is provided (None or empty), since no filter could be built
        """
        # A missing token list is treated like an empty one so that the explicit error below is raised
        security_clauses = [
            self._get_security_token_check_clause(caller_security_token)
            for caller_security_token in (caller_security_tokens or [])
        ]

        if not security_clauses:
            raise Exception("No caller security tokens provided for vector store security filter: impossible to build the filter")

        if len(security_clauses) == 1:
            # no need to wrap a single clause in an OR
            security_clauses_top = security_clauses[0]
        else:
            security_clauses_top = {
                "operator": "OR",
                "clauses": security_clauses,
            }

        self.add_filter(search_kwargs, security_clauses_top)

    @abstractmethod
    def _get_filter_field_name(self) -> str:
        """ Return the name of the field used to filter documents
        """
        raise NotImplementedError()

    def add_filter(self, search_kwargs: Dict, simple_filter: Dict) -> None:
        """ Add filters to a search query from a simple_filter set in the UI

        The converted filter is stored in ``search_kwargs`` under the vector store's filter field;
        if a filter is already present there, both are combined with an AND.

        :param search_kwargs: search arguments, updated in place
        :param simple_filter: the simple filter to convert and add
        """
        converted_filter = self.convert(simple_filter)

        filter_field = self._get_filter_field_name()
        if filter_field not in search_kwargs:
            search_kwargs[filter_field] = converted_filter
        else:
            search_kwargs[filter_field] = self._and_filter(search_kwargs[filter_field], converted_filter)

    @staticmethod
    def _sanitize_metadata_for_print(metadata: Dict) -> Dict:
        """
        :return: a copy of the metadata dict with shrinked multimodal content if required, the original dict otherwise
        :rtype: Dict
        """
        # Might be best to also shrink other fields of the dict ?
        content = metadata.get(DKU_MULTIMODAL_CONTENT)
        # Only a str can be truncated and suffixed with "..."; anything else (including None) is left as is
        # so that this printing helper never raises
        if isinstance(content, str) and len(content) > METADATA_SAMPLE_LENGTH:
            sanitized_meta = dict(metadata)  # make a copy not to modify the original document
            # DKU_MULTIMODAL_CONTENT can contains very long str, shrink it to a few chars to avoid flooding
            sanitized_meta[DKU_MULTIMODAL_CONTENT] = content[0: METADATA_SAMPLE_LENGTH] + "..."
            return sanitized_meta
        return metadata


class MongoDBLikeVectorStoreDocumentFilter(VectorStoreDocumentFilter):
    """ Generic document filter using MongoDB's syntax https://www.mongodb.com/docs/manual/reference/operator/query/

    Can be used with:
    - Chroma: https://docs.trychroma.com/docs/querying-collections/metadata-filtering
    - Pinecone: https://docs.pinecone.io/guides/data/understanding-metadata#metadata-query-language
    - FAISS: https://python.langchain.com/docs/integrations/vectorstores/faiss/#similarity-search-with-filtering
    """
    NEQ_OPERATOR = "$ne"  # FAISS uses "$neq"

    # Simple filter comparison operators whose MongoDB-style spelling is fixed.
    # NOT_EQUALS is not listed: its spelling comes from NEQ_OPERATOR so that it can be overridden.
    _COMPARISON_OPERATORS = {
        "EQUALS": "$eq",
        "GREATER_THAN": "$gt",
        "LESS_THAN": "$lt",
        "GREATER_OR_EQUAL": "$gte",
        "LESS_OR_EQUAL": "$lte",
        "IN_ANY_OF": "$in",
        "IN_NONE_OF": "$nin",
    }
    # Simple filter logical operators, applied to the converted sub-clauses
    _LOGICAL_OPERATORS = {
        "AND": "$and",
        "OR": "$or",
    }

    def convert(self, clause: Dict) -> Dict:
        """ Convert a simple filter clause to a MongoDB-style filter dict

        Comparison clauses become ``{column: {operator: value}}`` with the value coerced to the column type;
        AND / OR clauses become ``{"$and" | "$or": [converted sub-clauses]}``.

        :raises Exception: when the clause operator is not supported
        """
        op = clause["operator"]

        if op in self._LOGICAL_OPERATORS:
            # sub-clauses go through self.convert so that subclass overrides apply to them too
            return {self._LOGICAL_OPERATORS[op]: [self.convert(x) for x in clause["clauses"]]}

        if op == "NOT_EQUALS":
            mongo_operator = self.NEQ_OPERATOR
        else:
            mongo_operator = self._COMPARISON_OPERATORS.get(op)

        if mongo_operator is None:
            raise Exception("Unsupported filter operator: %s" % op)

        return {clause["column"]: {mongo_operator: self._coerce_value(clause)}}

    def _and_filter(self, converted_clause_a: Dict, converted_clause_b: Dict) -> Dict:
        """ Return the ``$and`` of two already converted filters
        """
        return {"$and": [converted_clause_a, converted_clause_b]}

    def _get_filter_field_name(self) -> str:
        """ Return the name of the search argument holding the filter
        """
        return "filter"

    def add_security_tokens_to_document(self, document: Document) -> Document:
        """ Add a boolean field per token

        The security tokens are read from the document metadata as a JSON array; each token gets its own
        prefixed metadata field set to "true". Documents without security tokens metadata are returned untouched.

        :param document: the document, whose metadata is updated in place
        :return: the same document
        :raises ValueError: when the security tokens metadata is not a valid JSON array
        """
        if DKU_SECURITY_TOKENS_META not in document.metadata:
            return document

        raw_tokens = document.metadata[DKU_SECURITY_TOKENS_META]
        try:
            security_tokens = json.loads(raw_tokens)
        except (ValueError, TypeError) as e:
            raise ValueError("Invalid format for security tokens. Expected a JSON Array: {}, {}".format(e, raw_tokens)) from e

        # Any other JSON value must be rejected: iterating a JSON string would yield one bogus token
        # per character, and iterating a JSON object would yield its keys
        if not isinstance(security_tokens, list):
            raise ValueError("Invalid format for security tokens. Expected a JSON Array: {}".format(raw_tokens))

        for token in security_tokens:
            document.metadata[self._get_security_token_with_prefix(token)] = "true"
        logger.info("Updated metadata to %s", self._sanitize_metadata_for_print(document.metadata))
        return document

    def _get_security_token_check_clause(self, token: str) -> Dict:
        """ Return the simple filter clause matching documents whose field for this token is "true"
        """
        return {
            "operator": "EQUALS",
            "column": self._get_security_token_with_prefix(token),
            "value": "true",
        }

    @staticmethod
    def _get_security_token_with_prefix(token: str) -> str:
        """ Return the name of the metadata field associated to the token
        """
        return DKU_SECURITY_TOKEN_META_PREFIX + token


class ChromaDBVectorStoreDocumentFilter(MongoDBLikeVectorStoreDocumentFilter):
    """ MongoDB-like document filter for ChromaDB, with dedicated handling of Bag Of Words columns

    A filter on a Bag Of Words column is rewritten into filters on per-value boolean fields named
    ``DKU_BOW_<column>__<value>``, then converted by the parent class.
    """

    @staticmethod
    def _bow_flag_clause(operator: str, column: str, value: Any) -> Dict:
        """ Return the simple filter clause comparing the boolean field of a (column, value) pair to True
        """
        return {"operator": operator, "column": f"DKU_BOW_{column}__{value}", "value": True}

    @staticmethod
    def _ensure_not_equals_is_supported() -> None:
        """ Raise if the chromadb package is too old for negative Bag Of Words filters

        Negative filters rely on "NOT_EQUALS" returning true on a missing metadata column:
        - if the column is there it means the category is there, and the document must be excluded
        - if the column is not there, the document must be kept
        https://docs.trychroma.com/docs/overview/migration#v0512---october-8-2024

        NOTE(review): the linked migration entry is for v0.5.12 while the check below is against 0.5.3
        - confirm which version is the right threshold.
        """
        if not package_is_at_least_no_import("chromadb", "0.5.3"):
            raise Exception("$ne filter is not supported for Bag Of Words filtering in ChromaDB < 0.5.3.")

    def convert(self, clause: Dict) -> Dict:
        """ Convert a simple filter clause to a ChromaDB filter dict

        :raises Exception: on an unsupported operator, or on a negative Bag Of Words filter with a too old chromadb
        """
        # Anything that is not a Bag Of Words column (including logical clauses) follows the generic conversion
        if self._get_column_type(clause) != ColumnType.BagOfWords:
            return super().convert(clause)

        # Bag Of Words case
        operator = clause["operator"]

        if operator == "EQUALS":
            rewritten = self._bow_flag_clause("EQUALS", clause["column"], clause["value"])
        elif operator == "NOT_EQUALS":
            self._ensure_not_equals_is_supported()
            rewritten = self._bow_flag_clause("NOT_EQUALS", clause["column"], clause["value"])
        elif operator == "IN_ANY_OF":
            # at least one of the values is present
            candidates = self._coerce_value(clause)
            rewritten = {
                "operator": "OR",
                "clauses": [self._bow_flag_clause("EQUALS", clause["column"], v) for v in candidates],
            }
        elif operator == "IN_NONE_OF":
            # none of the values is present
            self._ensure_not_equals_is_supported()
            candidates = self._coerce_value(clause)
            rewritten = {
                "operator": "AND",
                "clauses": [self._bow_flag_clause("NOT_EQUALS", clause["column"], v) for v in candidates],
            }
        else:
            raise Exception("Unsupported filter operator: %s" % clause["operator"])

        return super().convert(rewritten)


class FAISSVectorStoreDocumentFilter(MongoDBLikeVectorStoreDocumentFilter):
    """
    Implementation of filtering for FAISS vector stores using either dictionary-based or callable-based filters.

    Filters on standard columns are converted to MongoDB-style dictionaries. Filters on Bag Of Words columns
    are converted to callables taking the document metadata; as soon as a callable is involved in a
    combination (AND / OR), the whole combination becomes a callable and the dictionary parts are
    evaluated in memory.
    """
    # LangChain's FAISS dictionary filters spell "not equal" as "$neq", not "$ne"
    NEQ_OPERATOR = "$neq"

    @staticmethod
    def _matches_operator(op: str, meta_val: Any, op_val: Any) -> bool:
        """ Tell whether a metadata value satisfies a single MongoDB-style comparison

        :raises Exception: when the operator is not supported by the in-memory evaluator
        """
        if op == "$eq":
            return meta_val == op_val
        if op in ("$ne", "$neq"):
            return meta_val != op_val
        if op == "$in":
            return meta_val in op_val
        if op == "$nin":
            return meta_val not in op_val
        if op in ("$gt", "$gte", "$lt", "$lte"):
            # A missing value never satisfies an ordering comparison (comparing None would raise a TypeError)
            if meta_val is None:
                return False
            if op == "$gt":
                return meta_val > op_val
            if op == "$gte":
                return meta_val >= op_val
            if op == "$lt":
                return meta_val < op_val
            return meta_val <= op_val
        raise Exception(f"Unsupported op in evaluator: {op}")

    def _evaluate_standard_filter_in_memory(self, filter_dict: Dict, metadata: Dict) -> bool:
        """
        Manually evaluates a standard MongoDB-style dictionary filter against a document's metadata.
        This is used when we need to combine standard filters with custom callable filters.

        All the entries of ``filter_dict`` must hold for the metadata to match.

        :raises ValueError: when the condition on a column is not a dictionary
        :raises Exception: when a comparison operator is not supported
        """
        for key, value in filter_dict.items():
            if key == "$and":
                if not all(self._evaluate_standard_filter_in_memory(sub, metadata) for sub in value):
                    return False
            elif key == "$or":
                if not any(self._evaluate_standard_filter_in_memory(sub, metadata) for sub in value):
                    return False
            elif not isinstance(value, dict):
                raise ValueError(f"Expected a dictionary for filter value: {value}, got {type(value)}")
            else:
                # key is a column name, value maps operators to operands
                meta_val = metadata.get(key)
                for op, op_val in value.items():
                    if not self._matches_operator(op, meta_val, op_val):
                        return False
        return True

    def _apply_standard_or_custom_filter(self, filter_obj: Union[Dict, Callable], metadata: Dict) -> bool:
        """ Evaluate a converted filter, whether it is a callable or a MongoDB-style dictionary, against the metadata
        """
        if callable(filter_obj):
            return filter_obj(metadata)
        return self._evaluate_standard_filter_in_memory(filter_obj, metadata)

    def convert(self, clause: Dict) -> Union[Dict, Callable]:
        """ Convert a simple filter clause to a FAISS filter: a dictionary when possible, a callable otherwise

        :raises Exception: when the clause operator is not supported
        """
        operator = clause["operator"]

        # 1. Handle logic operators (AND/OR) which might mix dicts and callables
        if operator in ("AND", "OR"):
            converted_clauses = [self.convert(c) for c in clause["clauses"]]
            if not any(callable(c) for c in converted_clauses):
                return {("$and" if operator == "AND" else "$or"): converted_clauses}
            combine = all if operator == "AND" else any
            return lambda metadata: combine(self._apply_standard_or_custom_filter(c, metadata) for c in converted_clauses)

        # 2. Handle BagOfWords columns (always return a callable)
        if self._get_column_type(clause) == ColumnType.BagOfWords:
            column = clause["column"]
            # A missing (or None) metadata value is handled as an empty bag of words
            if operator == "EQUALS":
                value = clause["value"]
                return lambda metadata: value in (metadata.get(column) or [])
            elif operator == "NOT_EQUALS":
                value = clause["value"]
                return lambda metadata: value not in (metadata.get(column) or [])
            elif operator == "IN_ANY_OF":
                # the set of searched values is built once, not for every evaluated document
                searched_values = set(self._coerce_value(clause))
                return lambda metadata: not searched_values.isdisjoint(metadata.get(column) or [])
            elif operator == "IN_NONE_OF":
                searched_values = set(self._coerce_value(clause))
                return lambda metadata: searched_values.isdisjoint(metadata.get(column) or [])
            else:
                raise Exception("Unsupported filter operator: %s" % operator)

        # 3. Handle standard columns (return a dict)
        return super().convert(clause)

    def _and_filter(self, converted_clause_a: Any, converted_clause_b: Any) -> Any:
        """ Return the filter that ANDs two converted filters: a callable if any of them is one, a dictionary otherwise
        """
        if callable(converted_clause_a) or callable(converted_clause_b):
            return lambda metadata: self._apply_standard_or_custom_filter(converted_clause_a, metadata) and self._apply_standard_or_custom_filter(converted_clause_b, metadata)
        return {"$and": [converted_clause_a, converted_clause_b]}
