import json
import logging
from abc import abstractmethod
from datetime import datetime, date
from enum import Enum
from typing import Dict, List, Union, Optional

from langchain_core.documents import Document

from dataiku.langchain.document_handler import DKU_MULTIMODAL_CONTENT
from dataiku.langchain.metadata_generator import DKU_SECURITY_TOKENS_META

logger = logging.getLogger(__name__)

# Number of leading characters of the multimodal content kept when document metadata is printed in logs
METADATA_SAMPLE_LENGTH = 25
# Prefix of the per-token metadata fields (valued "true") used to enforce document-level security tokens
DKU_SECURITY_TOKEN_META_PREFIX = "DKU_SECURITY_TOKEN_"


class ColumnType(Enum):
    """ Simplified value type of a metadata column, used to decide which typed filter operators apply
    """
    Date = "Date"
    Decimal = "Decimal"
    Integer = "Integer"
    String = "String"

    @staticmethod
    def from_dss_type(dss_storage_type: str) -> Optional["ColumnType"]:
        """ Translate a DSS storage type into the ColumnType used to build typed filters.

        :param dss_storage_type: storage type of the column as known by DSS (None when the column is unknown)
        :return: the matching ColumnType, or None when the storage type gets no typed operator support
        """
        # Each entry: (storage types sharing a value type, resulting ColumnType); the groups are disjoint
        storage_type_groups = (
            (("date", "dateandtime", "dateonly"), ColumnType.Date),
            (("double", "float"), ColumnType.Decimal),
            (("bigint", "int", "smallint", "tinyint"), ColumnType.Integer),
            (("string",), ColumnType.String),
        )
        for accepted_storage_types, column_type in storage_type_groups:
            if dss_storage_type in accepted_storage_types:
                return column_type
        # "geopoint", "geometry", "array", "map", "object" (and anything unlisted) have no operator support
        return None


class VectorStoreDocumentFilter:
    """ A VectorStoreDocumentFilter implements DSS filters using vector store-specific methods, including document-level security tokens enforcement

    Subclasses supply the store-specific pieces (``convert``, ``_and_filter``, ``add_security_tokens_to_document``,
    ``_get_security_token_check_clause`` and ``_get_filter_field_name``); this base class combines them to add
    UI filters and security-token filters to the search kwargs of a query.

    NOTE(review): this class does not derive from ``abc.ABC``, so ``@abstractmethod`` is not enforced when a
    subclass is instantiated; a missing override only surfaces as NotImplementedError when it is called.
    """
    def __init__(self, metadata_column_type: Dict):
        """
        :param metadata_column_type: mapping of metadata column name -> DSS storage type (e.g. "int", "string"),
            used to decide how filter values are coerced
        """
        self.metadata_column_type = metadata_column_type

    def _get_column_type(self, clause: Dict) -> Optional[ColumnType]:
        """ Return the simplified type of the column targeted by ``clause``, or None if unknown / untyped
        """
        dss_storage_type = self.metadata_column_type.get(clause["column"], None)
        return ColumnType.from_dss_type(dss_storage_type)

    def _unsupported_type_for_operator(self, column_type: ColumnType, operator: str) -> Exception:
        """ Build (but do not raise) the error reported when ``operator`` cannot be applied to ``column_type``
        """
        return Exception(f"Unsupported type {column_type} for filter operator {operator}")

    def _coerce_value(self, clause: Dict) -> Union[int, List[int], float, List[float], str, List[str], datetime, List[datetime], date, List[date]]:
        """ Cast the clause value(s) to the Python type matching the column's storage type.

        Integer columns get ``int`` values and Decimal columns get ``float`` values. Any other column type
        (Date, String, unknown) has its value returned untouched - dates are not parsed here.
        List and tuple values (e.g. for IN_ANY_OF / IN_NONE_OF) are coerced element-wise into a list.
        """
        column_type = self._get_column_type(clause)
        value = clause["value"]
        if column_type == ColumnType.Integer:
            cast = int
        elif column_type == ColumnType.Decimal:
            cast = float
        else:
            return value
        if isinstance(value, (list, tuple)):
            return [cast(v) for v in value]
        return cast(value)

    @abstractmethod
    def convert(self, simple_filter: Dict) -> Dict:
        """ Convert a SimpleFilter to a vector store-specific filter object
        """
        raise NotImplementedError()

    @abstractmethod
    def _and_filter(self, converted_clause_a: Dict, converted_clause_b: Dict) -> Dict:
        """ Return the vector store-specific filter that ANDs the two vector store-specific filter clauses
        """
        raise NotImplementedError()

    @abstractmethod
    def add_security_tokens_to_document(self, document: Document) -> Document:
        """ Add fields used to implement security tokens based document filtering
        """
        raise NotImplementedError()

    @abstractmethod
    def _get_security_token_check_clause(self, token: str) -> Dict:
        """ Return a simple filter clause object that filters document to only return those that hold the given token
        """
        raise NotImplementedError()

    def add_security_filter(self, search_kwargs: Dict, caller_security_tokens: List) -> None:
        """ Add filters to a search query to enforce security tokens.

        A document matches when it holds at least one of the caller's tokens (the per-token clauses are ORed),
        and this condition is ANDed with any filter already present in ``search_kwargs``.

        :param search_kwargs: search arguments, modified in place
        :param caller_security_tokens: tokens held by the caller
        :raises Exception: if no caller security token is provided (empty or None)
        """
        security_clauses = [
            self._get_security_token_check_clause(caller_security_token)
            for caller_security_token in (caller_security_tokens or [])
        ]

        if not security_clauses:
            # Never fall back to an unfiltered search: refuse to build the query instead
            raise Exception("No caller security tokens provided for vector store security filter: impossible to build the filter")

        if len(security_clauses) == 1:
            # Avoid a single-entry OR, which some stores reject
            security_clauses_top = security_clauses[0]
        else:
            security_clauses_top = {
                "operator": "OR",
                "clauses": security_clauses,
            }

        self.add_filter(search_kwargs, security_clauses_top)

    @abstractmethod
    def _get_filter_field_name(self) -> str:
        """ Return the name of the field used to filter documents
        """
        raise NotImplementedError()

    def add_filter(self, search_kwargs: Dict, simple_filter: Dict) -> None:
        """ Add filters to a search query from a simple_filter set in the UI.

        The filter is converted to the store-specific format and stored under the store's filter field;
        if a filter is already there, both are combined with ``_and_filter``.

        :param search_kwargs: search arguments, modified in place
        :param simple_filter: DSS simple filter clause
        """
        converted_filter = self.convert(simple_filter)

        filter_field = self._get_filter_field_name()
        if filter_field not in search_kwargs:
            search_kwargs[filter_field] = converted_filter
        else:
            search_kwargs[filter_field] = self._and_filter(search_kwargs[filter_field], converted_filter)

    @staticmethod
    def _sanitize_metadata_for_print(metadata: Dict) -> Dict:
        """ Prepare document metadata for logging.

        :return: a copy of the metadata dict with shrunk multimodal content if required, the original dict otherwise
        :rtype: Dict
        """
        # TODO: might be best to also shrink other fields of the dict
        multimodal_content = metadata.get(DKU_MULTIMODAL_CONTENT)
        # Only str content is shrunk: this helper only feeds log messages, so a None / non-str value must not
        # make it raise
        if isinstance(multimodal_content, str) and len(multimodal_content) > METADATA_SAMPLE_LENGTH:
            sanitized_meta = dict(metadata)  # make a copy not to modify the original document
            # DKU_MULTIMODAL_CONTENT can contain very long str, shrink it to a few chars to avoid flooding
            sanitized_meta[DKU_MULTIMODAL_CONTENT] = multimodal_content[:METADATA_SAMPLE_LENGTH] + "..."
            return sanitized_meta
        return metadata


class MongoDBLikeVectorStoreDocumentFilter(VectorStoreDocumentFilter):
    """ Generic document filter using MongoDB's syntax https://www.mongodb.com/docs/manual/reference/operator/query/

    Can be used with:
    - Chroma: https://docs.trychroma.com/docs/querying-collections/metadata-filtering
    - Pinecone: https://docs.pinecone.io/guides/data/understanding-metadata#metadata-query-language
    - FAISS: https://python.langchain.com/docs/integrations/vectorstores/faiss/#similarity-search-with-filtering

    Security tokens are stored as one metadata field per token (prefixed name, value "true") and checked with
    an EQUALS clause on that field.
    """
    NEQ_OPERATOR = "$ne"  # FAISS uses "$neq"

    # DSS column operators -> MongoDB-style operators. NOT_EQUALS is resolved through NEQ_OPERATOR instead,
    # so that subclasses can override its spelling.
    _COLUMN_OPERATORS = {
        "EQUALS": "$eq",
        "GREATER_THAN": "$gt",
        "LESS_THAN": "$lt",
        "GREATER_OR_EQUAL": "$gte",
        "LESS_OR_EQUAL": "$lte",
        "IN_ANY_OF": "$in",
        "IN_NONE_OF": "$nin",
    }
    # DSS logical operators -> MongoDB-style operators taking a list of sub-filters
    _LOGICAL_OPERATORS = {
        "AND": "$and",
        "OR": "$or",
    }

    def convert(self, clause: Dict) -> Dict:
        """ Recursively convert a DSS simple filter clause to a MongoDB-style filter dict.

        :raises Exception: if the clause (or one of its sub-clauses) uses an unsupported operator
        """
        operator = clause["operator"]
        if operator in self._LOGICAL_OPERATORS:
            converted_clauses = [self.convert(sub_clause) for sub_clause in clause["clauses"]]
            if len(converted_clauses) == 1:
                # A single-clause AND/OR is equivalent to the clause itself; some stores (e.g. Chroma) reject
                # $and/$or lists with fewer than two entries
                return converted_clauses[0]
            return {self._LOGICAL_OPERATORS[operator]: converted_clauses}

        if operator == "NOT_EQUALS":
            mongo_operator = self.NEQ_OPERATOR
        else:
            mongo_operator = self._COLUMN_OPERATORS.get(operator)
        if mongo_operator is None:
            raise Exception("Unsupported filter operator: %s" % operator)
        return {clause["column"]: {mongo_operator: self._coerce_value(clause)}}

    def _and_filter(self, converted_clause_a: Dict, converted_clause_b: Dict) -> Dict:
        """ Combine two already-converted filters with $and
        """
        return {"$and": [converted_clause_a, converted_clause_b]}

    def _get_filter_field_name(self) -> str:
        """ The filter is passed to the store under the "filter" search kwarg
        """
        return "filter"

    def add_security_tokens_to_document(self, document: Document) -> Document:
        """ Add a string field valued "true" per security token of the document.

        The tokens are read from the DKU_SECURITY_TOKENS_META metadata entry, which must hold a JSON array of
        strings. Documents without that entry are returned unchanged.

        :raises ValueError: if the security tokens entry is not a JSON array of strings
        """
        if DKU_SECURITY_TOKENS_META not in document.metadata:
            return document

        raw_security_tokens = document.metadata[DKU_SECURITY_TOKENS_META]
        try:
            security_tokens = json.loads(raw_security_tokens)
        except (ValueError, TypeError) as e:
            raise ValueError("Invalid format for security tokens. Expected a JSON Array: {}, {}".format(e, raw_security_tokens)) from e
        # Validate before mutating the metadata: iterating a JSON string or object would otherwise grant one
        # token per character / key
        if not isinstance(security_tokens, list) or not all(isinstance(token, str) for token in security_tokens):
            raise ValueError("Invalid format for security tokens. Expected a JSON Array of strings: {}".format(raw_security_tokens))

        for token in security_tokens:
            document.metadata[self._get_security_token_with_prefix(token)] = "true"
        logger.info("Updated metadata to %s", self._sanitize_metadata_for_print(document.metadata))
        return document

    def _get_security_token_check_clause(self, token: str) -> Dict:
        """ Simple filter clause matching documents whose field for ``token`` is "true"
        """
        return {
            "operator": "EQUALS",
            "column": self._get_security_token_with_prefix(token),
            "value": "true",
        }

    @staticmethod
    def _get_security_token_with_prefix(token: str) -> str:
        """ Name of the metadata field that marks a document as holding ``token``
        """
        return DKU_SECURITY_TOKEN_META_PREFIX + token
