import enum
import json
import logging
import random
import time
from collections.abc import Callable
from typing import Dict, Optional, Any, List, Union
import pandas as pd

from azure.core import credentials
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import SearchableField, SearchField, SearchFieldDataType, SimpleField
from langchain_community.vectorstores import azuresearch
from langchain_core.vectorstores import VectorStore
from langchain_core.documents import Document

from dataiku.core.vector_stores.dku_vector_store import DkuRemoteVectorStore, logger
from dataiku.llm.types import RetrievableKnowledge, BaseVectorStoreQuerySettings
from dataikuapi.dss.admin import DSSConnectionInfo
from dataikuapi.dss.langchain import DKUEmbeddings
from dataiku.core.vector_stores.vector_store_document_filter import VectorStoreDocumentFilter, ColumnType
from dataiku.langchain.metadata_generator import DKU_SECURITY_TOKENS_META


# Only let WARNING and above through from the Azure SDK loggers (the "azure" logger and its children).
logging.getLogger(name="azure").setLevel(level=logging.WARNING)


# Semantic configuration name handed to the langchain AzureSearch wrapper in AzureAISearchVectorStore.get_db()
SEMANTIC_CONFIGURATION_NAME = "DKU__SEMANTIC_CONFIG"

class AzureAISearchType(enum.Enum):
    """ Search type identifiers returned by AzureAISearchVectorStore._get_search_type().

    NOTE(review): the string values are presumably consumed as the langchain AzureSearch retriever's
    search_type — confirm against the langchain version in use before renaming any value.
    """

    # query searchType is neither HYBRID nor SIMILARITY_THRESHOLD
    SIMILARITY = "similarity"
    # query searchType == "SIMILARITY_THRESHOLD"
    SIMILARITY_THRESHOLD = "similarity_score_threshold"
    # query searchType == "HYBRID" without advanced reranking
    HYBRID = "hybrid"
    # query searchType == "HYBRID" with useAdvancedReranking enabled
    SEMANTIC_HYBRID = "semantic_hybrid"


class AzureAISearchVectorStore(DkuRemoteVectorStore):
    """ DSS remote vector store backed by an Azure AI Search index.

    Credentials are resolved from the DSS connection in init_connection(). They are used both by the
    SearchIndexClient owned by this class (index management, e.g. clear_index()) and by the langchain
    AzureSearch wrapper built in get_db().
    """

    AZUREAISEARCH_RESOURCE_URL_FORMAT = "https://{}.search.windows.net"
    DEFAULT_OAUTH_SCOPE = "https://search.azure.com/.default"

    def __init__(self, kb: RetrievableKnowledge, exec_folder: str, connection_info_retriever: Callable[[str], DSSConnectionInfo]):
        # NOTE(review): these attributes are initialised before super().__init__(), presumably because the base
        # constructor may call init_connection() which fills them in — confirm before reordering.
        self.index_client: Optional[SearchIndexClient] = None
        self.azure_search_key: Optional[str] = None
        self.azure_ai_search_url: Optional[str] = None
        self.azure_credential: Optional[Union[credentials.TokenCredential, credentials.AzureKeyCredential]] = None  # to connect to Azure SearchIndexClient on our side
        self.auth_kwargs: Dict = {}  # for langchain (also connecting to the SearchIndexClient but internally)
        super().__init__(kb, exec_folder, connection_info_retriever)
        self.document_filter = AzureAISearchVectorStoreDocumentFilter(self.metadata_column_type)

    def init_connection(self) -> None:
        """ Resolve the Azure AI Search credentials from the DSS connection and create the SearchIndexClient.

        Sets azure_ai_search_url, azure_credential and auth_kwargs (plus azure_search_key for API key auth),
        then creates index_client.

        :raises Exception: if the connection authType is neither OAUTH2_APP nor API_KEY
        """

        class AzureBearerTokenCredential(credentials.TokenCredential):
            """ Minimal TokenCredential that always hands out the access token already resolved by DSS. """

            def __init__(self, access_token):
                # The real expiry of the token is not known here; it is reported as valid for one hour from now
                self.az_access_token = credentials.AccessToken(access_token, int(time.time()) + 3600)

            def get_token(self, *args, **kwargs):
                # Requested scopes and options are ignored: the same pre-resolved token is always returned
                return self.az_access_token

        connection_infos = self.connection_info_retriever(self.connection_name)
        connection_params = connection_infos.get_params()
        self.azure_ai_search_url = self.build_azure_ai_search_resource_url(connection_params['resourceName'])

        if connection_params["authType"] == "OAUTH2_APP":
            logger.info("Initiating the connection to Azure AI Search with Oauth ({})".format(
                connection_infos.get_credential_mode())
            )
            # Global & Per user oauth: retrieve the credentials resolved in the Java part
            resolved_access_token = connection_infos.get_oauth2_credential()['accessToken']
            self.azure_credential = AzureBearerTokenCredential(resolved_access_token)

            self.auth_kwargs = {
                "azure_ad_access_token": resolved_access_token,
                "azure_search_key": None  # required: as of today azure_search_key isn't optional nor has default value in langchain api
            }

        elif connection_params["authType"] == "API_KEY":
            logger.info("Initiating the connection to Azure AI Search with Api Key")

            self.azure_search_key = connection_params['apiKey']
            self.azure_credential = credentials.AzureKeyCredential(self.azure_search_key)

            self.auth_kwargs = {"azure_search_key": self.azure_search_key}

        else:
            raise Exception("Unsupported authorization type {} for Azure AI Search connection".format(connection_params["authType"]))

        # Create a search index client once to be able to clear the index when required
        self.index_client = SearchIndexClient(endpoint=self.azure_ai_search_url, credential=self.azure_credential)

    @classmethod
    def build_azure_ai_search_resource_url(cls, name_or_url: str) -> str:
        """ Return name_or_url unchanged if it is already an https:// URL, otherwise treat it as a resource name
        and build https://<name>.search.windows.net.
        """
        # (http unsupported by the api)
        return name_or_url if name_or_url.startswith("https://") else cls.AZUREAISEARCH_RESOURCE_URL_FORMAT.format(name_or_url)

    @staticmethod
    def _get_search_kwargs(query_settings: BaseVectorStoreQuerySettings) -> Dict:
        """ Search kwargs for the retriever: only the similarity threshold arguments from the base class. """
        # Neither mmr nor k kwargs are supported for azure AI retriever (search kwargs)
        return {**DkuRemoteVectorStore._get_similarity_threshold_args(query_settings)}

    @staticmethod
    def _get_retriever_kwargs(query_settings: BaseVectorStoreQuerySettings) -> Dict:
        """ Retriever kwargs: the number of documents to return (maxDocuments, default 4). """
        # For azure AI the k is defined on the retriever kwargs directly
        return {"k": query_settings.get("maxDocuments", 4)}

    @staticmethod
    def _get_search_type(query_settings: BaseVectorStoreQuerySettings) -> str:
        """ Map the DSS query settings to an AzureAISearchType value.

        HYBRID searchType gives "hybrid", or "semantic_hybrid" when useAdvancedReranking is set;
        SIMILARITY_THRESHOLD gives "similarity_score_threshold"; anything else gives "similarity".
        """
        hybrid_search = query_settings.get("searchType") == "HYBRID"
        advanced_reranking = query_settings.get("useAdvancedReranking", False)
        similarity_threshold = query_settings.get("searchType") == "SIMILARITY_THRESHOLD"
        # Assuming that if hybrid_search isn't enabled then advanced_reranking is irrelevant
        if hybrid_search:
            if advanced_reranking:
                return AzureAISearchType.SEMANTIC_HYBRID.value
            else:
                return AzureAISearchType.HYBRID.value
        else:
            if similarity_threshold:
                return AzureAISearchType.SIMILARITY_THRESHOLD.value
            else:
                return AzureAISearchType.SIMILARITY.value

    def set_index_name(self, index_name: str) -> None:
        """ Store the index name, lowercased.

        NOTE(review): presumably lowercased because Azure AI Search only accepts lowercase index names.
        """
        self.index_name = index_name.lower()

    def clear_index(self) -> None:
        """ Delete the whole index named self.index_name through the SearchIndexClient.

        NOTE(review): SearchIndexClient.delete_index presumably raises if the index does not exist — confirm
        callers expect that.
        """
        self.index_client.delete_index(self.index_name)

    def get_db(self, embeddings: DKUEmbeddings, allow_creation: bool = False, **kwargs: Any) -> VectorStore:
        """ Build the langchain AzureSearch vector store for self.index_name.

        The index schema is the required langchain fields, one filterable (non-searchable) field per metadata
        column, and a string collection holding the document security tokens. allow_creation and kwargs are
        currently not used.
        """
        # todo should check if index already exist to raise an error if allow_creation=false first (langchain always create it by default if unfound)

        fields_from_metadata = [
            SimpleField(
                name=column_name,
                type=get_azure_data_type(column_type, column_name),
                filterable=True,
                searchable=False,  # don't want these fields used in hybrid search
            )
            for column_name, column_type in self.metadata_column_type.items()
        ]

        # field used to enforce document-level security tokens
        security_tokens_field = [
            SimpleField(
                name=DKU_SECURITY_TOKENS_META,
                type=SearchFieldDataType.Collection(SearchFieldDataType.String),
                filterable=True,
                searchable=False,
            )
        ]

        db = azuresearch.AzureSearch(
            azure_search_endpoint=self.azure_ai_search_url,
            index_name=self.index_name,
            embedding_function=embeddings,
            semantic_configuration_name=SEMANTIC_CONFIGURATION_NAME,
            fields=self.get_required_fields() + fields_from_metadata + security_tokens_field,
            **self.auth_kwargs
        )
        return db

    def get_required_fields(self) -> List[SearchField]:
        """ Return the index fields the langchain AzureSearch wrapper relies on: id, content, content_vector, metadata. """
        # These fields are required by Azure AI search
        # From the AzureSearch langchain wrapper, see the definition of `default_fields` in `AzureSearch.__init__`
        # https://python.langchain.com/api_reference/_modules/langchain_community/vectorstores/azuresearch.html
        return [
            SimpleField(
                name="id",
                type=SearchFieldDataType.String,
                key=True,
                filterable=True,
            ),
            SearchableField(
                name="content",
                type=SearchFieldDataType.String,
                searchable=True,
            ),
            SearchField(
                name="content_vector",
                type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
                searchable=True,
                vector_search_dimensions=self.get_vector_size(),
                vector_search_profile_name="myHnswProfile",  # a profile with this name is created by the AzureSearch langchain wrapper
            ),
            SearchableField(
                name="metadata",
                type=SearchFieldDataType.String,
                searchable=True,
            )
        ]

    def process_timestamp(self, value: pd.Timestamp) -> str:
        """ Convert a pandas timestamp to an ISO 8601 UTC string, used for the Azure timestamp format.

        The output always ends with the "Z" (UTC) designator. Timezone-aware timestamps are therefore converted
        to UTC first, so the instant is preserved. Naive timestamps are assumed to already be in UTC.
        """
        if value.tzinfo is not None:
            value = pd.Timestamp(value).tz_convert("UTC")
        return value.strftime("%Y-%m-%dT%H:%M:%S.%fZ")


class AzureAISearchVectorStoreDocumentFilter(VectorStoreDocumentFilter):
    """ Azure AI Search uses the OData language for filters

    See https://learn.microsoft.com/en-us/azure/search/query-odata-filter-orderby-syntax
    and https://docs.oasis-open.org/odata/odata-json-format/v4.01/odata-json-format-v4.01.html
    """

    # DSS filter operator -> OData comparison operator
    _COMPARISON_OPERATORS = {
        "EQUALS": "eq",
        "NOT_EQUALS": "ne",
        "GREATER_THAN": "gt",
        "LESS_THAN": "lt",
        "GREATER_OR_EQUAL": "ge",
        "LESS_OR_EQUAL": "le",
    }

    # Single-character search.in() delimiters, tried in order. The quote and backslash are deliberately
    # excluded so the delimiter never interferes with the surrounding OData string literal.
    _SEPARATOR_CANDIDATES = "|~^;#+=_-!*@$"

    def escape_string(self, value: str) -> str:
        """ Make sure string values can't interfere with the filter query, see https://learn.microsoft.com/en-us/azure/search/query-odata-filter-orderby-syntax#escaping-special-characters-in-string-constants
        """
        return value.replace("'", "''")

    def _coerce_value(self, clause: Dict) -> Union[List[str], str]:
        """ Coerce the clause value via the parent class, then make strings safe to embed in an OData expression.

        A scalar string becomes a quoted and escaped OData string literal. Strings inside a list are escaped but
        NOT quoted, because callers join them into a single quoted search.in() argument. Other values are
        returned unchanged.
        """
        value = super()._coerce_value(clause)
        if isinstance(value, str):
            return "'" + self.escape_string(value) + "'"
        if isinstance(value, list):
            return [self.escape_string(v) if isinstance(v, str) else v for v in value]
        return value

    def choose_separator(self, list_of_strings: List[str], num_tries: int = 100) -> str:
        """ Pick a search.in() delimiter that cannot split any of the values we search for.

        search.in() treats EVERY character of its delimiters argument as a separator, see
        https://learn.microsoft.com/en-us/azure/search/search-query-odata-search-in-function
        The delimiter is therefore a single character that appears in none of the values. Checking that a
        multi-character string is absent is not enough.

        :param list_of_strings: the (already escaped) values that will be joined with the separator
        :param num_tries: maximum number of candidate characters examined (kept for backward compatibility)
        :return: a single-character separator absent from every value
        :raises RuntimeError: if no candidate character is absent from all values
        """
        used_characters = set(''.join(list_of_strings))
        for candidate in self._SEPARATOR_CANDIDATES[:max(num_tries, 0)]:
            if candidate not in used_characters:
                return candidate
        raise RuntimeError("failed to generate a separator that is not part of the search query")

    def _search_in(self, clause: Dict) -> str:
        """ Build a search.in() call matching the clause column against the clause's list of string values. """
        list_of_strings = self._coerce_value(clause)
        separator = self.choose_separator(list_of_strings)
        values = separator.join(list_of_strings)
        return f"search.in({clause['column']}, '{values}', '{separator}')"

    def convert(self, clause: Dict) -> str:
        """ Convert a (possibly nested) DSS filter clause into an OData filter expression.

        :raises Exception: if the operator is not supported by this backend
        """
        operator = clause["operator"]
        if operator in self._COMPARISON_OPERATORS:
            return f'({clause["column"]} {self._COMPARISON_OPERATORS[operator]} {self._coerce_value(clause)})'
        elif operator == "IN_ANY_OF":
            column_type = self._get_column_type(clause)
            if column_type in [ColumnType.String, ColumnType.Date]:
                # search.in() is more performant but only available for strings, see https://learn.microsoft.com/en-us/azure/search/search-query-odata-search-in-function
                return self._search_in(clause)
            elif column_type in [ColumnType.Decimal, ColumnType.Integer]:
                list_of_numbers = super()._coerce_value(clause)
                return '(' + ' or '.join(f'({clause["column"]} eq {v})' for v in list_of_numbers) + ')'
            else:
                raise self._unsupported_type_for_operator(column_type, clause['operator'])
        elif operator == "IN_NONE_OF":
            column_type = self._get_column_type(clause)
            if column_type in [ColumnType.String, ColumnType.Date]:
                return f"(not {self._search_in(clause)})"
            elif column_type in [ColumnType.Decimal, ColumnType.Integer]:
                list_of_numbers = super()._coerce_value(clause)
                return '(' + ' and '.join(f'({clause["column"]} ne {v})' for v in list_of_numbers) + ')'
            else:
                raise self._unsupported_type_for_operator(column_type, clause['operator'])
        elif operator == "CONTAINS":
            return f'search.ismatch(\'{self.escape_string(clause["value"])}\', \'{clause["column"]}\')'
        elif operator == "AND":
            return '(' + ' and '.join([self.convert(x) for x in clause["clauses"]]) + ')'
        elif operator == "OR":
            return '(' + ' or '.join([self.convert(x) for x in clause["clauses"]]) + ')'
        # custom operator to search a string in a list of strings, used for security tokens, not exposed to users
        elif operator == "ARRAY_CONTAINS":
            return f'{clause["column"]}/any(v: v eq \'{self.escape_string(clause["value"])}\')'
        else:
            raise Exception("Unsupported filter operator for Azure AI Search: %s" % clause["operator"])

    def _and_filter(self, converted_clause_a: str, converted_clause_b: str) -> str:
        """ Combine two already-converted OData expressions with a logical and. """
        return f'({converted_clause_a} and {converted_clause_b})'

    def _get_filter_field_name(self) -> str:
        return "filters"

    def add_security_tokens_to_document(self, document: Document) -> Document:
        """ Add a field with the list of security tokens useful to implement add_security_filter()

        The security tokens metadata is decoded from its JSON string form so it matches the
        Collection(String) index field. A value that is already decoded (not a string) is left untouched.
        """
        tokens = document.metadata.get(DKU_SECURITY_TOKENS_META)
        if isinstance(tokens, (str, bytes, bytearray)):
            document.metadata[DKU_SECURITY_TOKENS_META] = json.loads(tokens)
            logging.info("Updated metadata to %s", self._sanitize_metadata_for_print(document.metadata))

        return document

    def _get_security_token_check_clause(self, security_token: str) -> Dict:
        """ Clause matching documents whose security tokens collection contains security_token. """
        return {
            "operator": "ARRAY_CONTAINS",
            "column": DKU_SECURITY_TOKENS_META,
            "value": security_token,
        }


def get_azure_data_type(dss_storage_type: str, column_name: str) -> SearchFieldDataType:
    """ Map a DSS storage type to the Azure SearchFieldDataType of the corresponding index field.

    DSS storage types are listed in the Java class com.dataiku.dip.datasets.Type. Azure data types are listed in
    https://learn.microsoft.com/en-us/dotnet/api/azure.search.documents.indexes.models.searchfielddatatype?view=azure-dotnet#properties

    Unknown storage types fall back to SearchFieldDataType.String and log a warning.
    """
    type_table = (
        (("string",), SearchFieldDataType.String),
        (("date", "dateandtime", "dateonly"), SearchFieldDataType.DateTimeOffset),
        # Not using Azure's geometry type because this would require converting the data into Azure's format
        (("geopoint", "geometry"), SearchFieldDataType.String),
        # TODO @azureai-filtering Possibly add schema for complex fields (NB: Needs Azure API version >= 2019-05-06)
        #  https://learn.microsoft.com/en-gb/azure/search/search-howto-complex-data-types?tabs=portal
        #  https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/search/azure-search-documents#creating-an-index
        (("array", "map", "object"), SearchFieldDataType.String),
        (("boolean",), SearchFieldDataType.Boolean),
        (("double", "float"), SearchFieldDataType.Double),
        (("bigint",), SearchFieldDataType.Int64),
        (("int", "smallint", "tinyint"), SearchFieldDataType.Int32),
    )
    for dss_types, azure_type in type_table:
        if dss_storage_type in dss_types:
            return azure_type

    logging.warning(f"Unknown storage type {dss_storage_type} for metadata column {column_name}, attempting to use SearchFieldDataType.String")
    return SearchFieldDataType.String
