import json
import logging
import os
import shutil
from typing import Any, Dict

import portalocker
from langchain_core.documents import Document
from langchain_core.vectorstores import VectorStore
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchAny, MatchExcept, MatchValue, Range
from qdrant_client.http import models as rest_models
from qdrant_client.http.models import VectorParams, Distance
from qdrant_client.local.local_collection import LocalCollection
from qdrant_client.local.qdrant_local import QdrantLocal, META_INFO_FILENAME

from dataiku.core.vector_stores.dku_vector_store import DkuLocalVectorStore, logger
from dataiku.core.vector_stores.vector_store_document_filter import VectorStoreDocumentFilter, DKU_SECURITY_TOKEN_META_PREFIX, ColumnType
from dataiku.langchain.metadata_generator import DKU_SECURITY_TOKENS_META
from dataiku.llm.types import RetrievableKnowledge
from dataikuapi.dss.langchain import DKUEmbeddings

try:
    from langchain_qdrant import QdrantVectorStore as LangchainQdrant  # type: ignore
except ImportError:
    from langchain_community.vectorstores import Qdrant as LangchainQdrant  # type: ignore

def load_qdrant_monkey_patch(local_client: QdrantLocal) -> None:
    """Replacement for ``QdrantLocal._load``, installed at the bottom of this block.

    It creates (or reads) the Qdrant meta file and rebuilds the client's
    collections and aliases from it. It then creates the ``.lock`` file if it is
    missing and takes an exclusive, non-blocking lock on it.

    Unlike the stock method, the lock file is opened read-only (``"r"`` instead
    of ``"r+"``). This lets readers use the knowledge bank in UIF setups. The
    lock file should therefore be created only by unimpersonated DSS code, or by
    a user with write permission on the folder.

    :param local_client: the ``QdrantLocal`` instance being initialised.
    :raises RuntimeError: if another client already holds the lock on the
        storage folder.
    """
    if not local_client.persistent:
        # In-memory client: nothing on disk to load or lock.
        return

    storage_dir = local_client.location
    meta_path = os.path.join(storage_dir, META_INFO_FILENAME)

    if os.path.exists(meta_path):
        # Existing storage: rebuild every collection and the alias table from the meta file.
        with open(meta_path, "r") as meta_file:
            meta = json.load(meta_file)
            for name, raw_config in meta["collections"].items():
                local_client.collections[name] = LocalCollection(
                    rest_models.CreateCollection(**raw_config),
                    local_client._collection_path(name),
                    force_disable_check_same_thread=local_client.force_disable_check_same_thread,
                )
            local_client.aliases = meta["aliases"]
    else:
        # Fresh storage: start from an empty meta file.
        os.makedirs(storage_dir, exist_ok=True)
        with open(meta_path, "w") as meta_file:
            meta_file.write(json.dumps({"collections": {}, "aliases": {}}))

    lock_path = os.path.join(storage_dir, ".lock")
    if not os.path.exists(lock_path):
        os.makedirs(storage_dir, exist_ok=True)
        with open(lock_path, "w") as seed_file:
            seed_file.write("tmp lock file")

    # Important: keep this handle read-only (the stock implementation opens it "r+"),
    # so that users without write permission on the lock file can still open it.
    lock_handle = open(lock_path, "r")
    local_client._flock_file = lock_handle
    try:
        portalocker.lock(lock_handle, portalocker.LockFlags.EXCLUSIVE | portalocker.LockFlags.NON_BLOCKING)
    except portalocker.exceptions.LockException:
        raise RuntimeError(
            f"Storage folder {storage_dir} is already accessed by another instance of Qdrant client."
            f" If you require concurrent access, use Qdrant server instead."
        )


QdrantLocal._load = load_qdrant_monkey_patch  # type: ignore


class QDrantLocalVectorStore(DkuLocalVectorStore):
    """Knowledge-bank vector store backed by the embedded, file-based Qdrant client.

    All Qdrant files live in ``exec_folder``. The collection used is named
    ``"knowledge_bank"``.
    """

    def __init__(self, kb: RetrievableKnowledge, exec_folder: str):
        super(QDrantLocalVectorStore, self).__init__(kb, exec_folder, collection_name="knowledge_bank")
        self.document_filter = QdrantLocalVectorStoreDocumentFilter(self.metadata_column_type)

    def get_db(self, embeddings: DKUEmbeddings, allow_creation: bool = False, **kwargs: Any) -> VectorStore:
        """Open the local Qdrant storage and wrap it in a LangChain vector store.

        If the collection does not exist yet, it is created with cosine distance
        and the vector size returned by ``get_vector_size()``.

        :param embeddings: embedding model handed to the LangChain wrapper.
        :param allow_creation: currently not checked; the collection is created
            whenever it is missing (see the TODO below).
        :param kwargs: unused here.
        :return: a LangChain Qdrant vector store bound to the collection.
        """
        client = QdrantClient(path=self.exec_folder)
        try:
            if not client.collection_exists(self.collection_name):  # todo should check for allow_creation first
                embedding_size = self.get_vector_size()
                client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(size=embedding_size, distance=Distance.COSINE),
                )
            return LangchainQdrant(client, self.collection_name, embeddings)
        except Exception:
            # The local client holds an exclusive lock on its storage folder
            # (see load_qdrant_monkey_patch). Release it now rather than whenever
            # the client is garbage-collected, so that a retry in the same process
            # is not rejected as concurrent access.
            client.close()
            raise

    def clear_files(self, folder_path: str) -> None:
        """Delete the Qdrant files found in ``folder_path``.

        The files are the meta file, the ``collection`` directory and the
        ``.lock`` file. Entries that do not exist are skipped.
        """
        # Use qdrant's own constant so this stays in sync with the file QdrantLocal writes.
        meta_path = os.path.join(folder_path, META_INFO_FILENAME)
        if os.path.isfile(meta_path):
            os.remove(meta_path)

        collection_dir = os.path.join(folder_path, "collection")
        if os.path.isdir(collection_dir):
            shutil.rmtree(collection_dir)

        lock_path = os.path.join(folder_path, ".lock")
        if os.path.isfile(lock_path):
            os.remove(lock_path)

        logger.info("Cleared QDrant vector store files at {}".format(folder_path))

    def get_file_size(self) -> int:
        """Return the total size in bytes of the Qdrant files in ``exec_folder``.

        The ``collection`` directory is walked recursively. The meta file and the
        ``.lock`` file are added as single files. Missing entries contribute 0.
        """
        tracked_paths = (
            os.path.join(self.exec_folder, "collection"),
            os.path.join(self.exec_folder, META_INFO_FILENAME),
            os.path.join(self.exec_folder, ".lock"),
        )

        total = 0
        for path in tracked_paths:
            if os.path.isfile(path):
                total += os.path.getsize(path)
            elif os.path.isdir(path):
                for dir_name, _, file_names in os.walk(path):
                    for file_name in file_names:
                        file_path = os.path.join(dir_name, file_name)
                        if os.path.isfile(file_path):
                            total += os.path.getsize(file_path)

        return total

class QdrantLocalVectorStoreDocumentFilter(VectorStoreDocumentFilter):
    """ Qdrant document filter

    Converts DSS filter clauses into ``qdrant_client`` ``Filter`` objects.
    Leaf clauses carry ``"operator"``, ``"column"`` and a value; ``AND``/``OR``
    clauses carry nested ``"clauses"``. Every field condition targets
    ``metadata.<column>`` in the point payload.

    See https://qdrant.tech/documentation/concepts/filtering/#filtering
    and https://python.langchain.com/docs/integrations/vectorstores/qdrant/#metadata-filtering
    """

    # Ordering operators -> name of the ``Range`` bound they set.
    _QDRANT_RANGE_BOUNDS = {
        "GREATER_THAN": "gt",
        "LESS_THAN": "lt",
        "GREATER_OR_EQUAL": "gte",
        "LESS_OR_EQUAL": "lte",
    }

    def convert(self, clause: Dict) -> Filter:
        """Convert a (possibly nested) DSS filter clause into a Qdrant ``Filter``.

        :param clause: dict with an ``"operator"`` key. Leaf clauses also have a
            ``"column"`` and a value read through ``_coerce_value``. ``AND``/``OR``
            clauses hold their sub-clauses under ``"clauses"``.
        :return: the equivalent ``Filter``.
        :raises: the error built by ``_unsupported_type_for_operator`` when the
            column type cannot be used with the operator, or ``Exception`` for an
            unknown operator.
        """
        operator = clause["operator"]

        if operator == "EQUALS":
            column_type = self._get_column_type(clause)
            # Security-token columns are always matched exactly, whatever type is reported for them.
            if column_type in [ColumnType.String, ColumnType.Integer] or clause["column"].startswith(DKU_SECURITY_TOKEN_META_PREFIX):
                return Filter(must=[self._match_value(clause, self._coerce_value(clause))])
            if column_type == ColumnType.Decimal:
                return Filter(must=[self._decimal_equals(clause, self._coerce_value(clause))])
            raise self._unsupported_type_for_operator(column_type, operator)

        if operator == "NOT_EQUALS":
            column_type = self._get_column_type(clause)
            if column_type in [ColumnType.String, ColumnType.Integer]:
                return Filter(must_not=[self._match_value(clause, self._coerce_value(clause))])
            if column_type == ColumnType.Decimal:
                return Filter(must_not=[self._decimal_equals(clause, self._coerce_value(clause))])
            raise self._unsupported_type_for_operator(column_type, operator)

        if operator in self._QDRANT_RANGE_BOUNDS:
            bound = self._QDRANT_RANGE_BOUNDS[operator]
            return Filter(must=[FieldCondition(key=self._payload_key(clause), range=Range(**{bound: self._coerce_value(clause)}))])

        if operator == "IN_ANY_OF":
            column_type = self._get_column_type(clause)
            if column_type in [ColumnType.String, ColumnType.Integer]:
                return Filter(must=[FieldCondition(key=self._payload_key(clause), match=MatchAny(any=self._coerce_value(clause)))])
            if column_type == ColumnType.Decimal:
                # At least one of the per-value closed ranges must match.
                return Filter(should=[self._decimal_equals(clause, value) for value in self._coerce_value(clause)])
            raise self._unsupported_type_for_operator(column_type, operator)

        if operator == "IN_NONE_OF":
            column_type = self._get_column_type(clause)
            if column_type in [ColumnType.String, ColumnType.Integer]:
                # "except" is a Python keyword, hence the dict unpacking.
                return Filter(must=[FieldCondition(key=self._payload_key(clause), match=MatchExcept(**{"except": self._coerce_value(clause)}))])
            if column_type == ColumnType.Decimal:
                return Filter(must_not=[self._decimal_equals(clause, value) for value in self._coerce_value(clause)])
            raise self._unsupported_type_for_operator(column_type, operator)

        if operator == "AND":
            return Filter(must=[self.convert(sub_clause) for sub_clause in clause["clauses"]])
        if operator == "OR":
            return Filter(should=[self.convert(sub_clause) for sub_clause in clause["clauses"]])

        raise Exception("Unsupported filter operator for QdrantLocal: %s" % operator)

    @staticmethod
    def _payload_key(clause: Dict) -> str:
        """Return the payload path of the clause's column, which is nested under ``metadata``."""
        return f'metadata.{clause["column"]}'

    def _match_value(self, clause: Dict, value: Any) -> FieldCondition:
        """Build an exact-match condition on the clause's column."""
        return FieldCondition(key=self._payload_key(clause), match=MatchValue(value=value))

    def _decimal_equals(self, clause: Dict, value: Any) -> FieldCondition:
        """Express equality on a decimal column as the closed range ``[value, value]``.

        see https://qdrant.tech/articles/vector-search-filtering/#filtering-with-float-point-decimal-numbers
        """
        return FieldCondition(key=self._payload_key(clause), range=Range(gte=value, lte=value))

    def _and_filter(self, converted_clause_a: Dict, converted_clause_b: Dict) -> Dict:
        """Combine two already-converted clauses with a logical AND.

        Despite the ``Dict`` annotations, presumably kept to match the base-class
        signature, the arguments and the result are ``Filter`` objects here.
        """
        return Filter(must=[converted_clause_a, converted_clause_b])

    def _get_filter_field_name(self) -> str:
        """Name of the search keyword argument that receives the converted filter."""
        return "filter"

    def add_security_tokens_to_document(self, document: Document) -> Document:
        """ Flag each security token of ``document`` as a metadata field set to ``"true"``

        The tokens are read from the JSON-encoded ``DKU_SECURITY_TOKENS_META``
        entry. Each one becomes a prefixed column that
        ``_get_security_token_check_clause`` can match with ``EQUALS "true"``.
        The document is modified in place and returned.
        """
        if DKU_SECURITY_TOKENS_META in document.metadata:
            security_tokens = json.loads(document.metadata[DKU_SECURITY_TOKENS_META])
            for token in security_tokens:
                document.metadata[self._get_security_token_with_prefix(token)] = "true"
            # Use the module's vector-store logger (as elsewhere in this file), not the root logger.
            logger.info("Updated metadata to %s", self._sanitize_metadata_for_print(document.metadata))
        return document

    def _get_security_token_check_clause(self, token: str) -> Dict:
        """Build the DSS clause requiring ``token``'s flag column to equal ``"true"``."""
        return {
            "operator": "EQUALS",
            "column": self._get_security_token_with_prefix(token),
            "value": "true",
        }

    @staticmethod
    def _get_security_token_with_prefix(token: str) -> str:
        """Return the metadata column name used for ``token``."""
        return DKU_SECURITY_TOKEN_META_PREFIX + token
