import asyncio
import itertools
import logging
import os
import traceback
import sys
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Awaitable
from typing import Callable
from typing import Dict
from typing import List

from presidio_analyzer import AnalyzerEngine, BatchAnalyzerEngine, RecognizerResult
from presidio_analyzer.nlp_engine import NlpEngineProvider
from presidio_anonymizer import AnonymizerEngine
from presidio_anonymizer.entities import OperatorConfig

from dataiku.base.async_link import AsyncJavaLink
from dataiku.base.batcher import Batcher
from dataiku.base.socket_block_link import parse_javalink_args
from dataiku.llm.types import ImageGenerationPrompt, ProcessSinglePromptCommand, ProcessCompletionResponseCommand, ProcessSingleEmbeddingCommand, ProcessSingleImageGenerationCommand

from dataiku.base.utils import watch_stdin

logger = logging.getLogger(__name__)

# Language used when no supported languages are configured, and the preferred fallback
# (when it is among the supported languages) if automatic language detection fails.
DEFAULT_LANGUAGE = "en"

# spaCy model to load for each language code that may appear in the "supportedLanguages"
# setting. Each value is passed as one entry of the NlpEngineProvider "models" configuration.
MODELS = {
    "en": {"lang_code": "en", "model_name": "en_core_web_md"},
    "fr": {"lang_code": "fr", "model_name": "fr_core_news_sm"},
    "de": {"lang_code": "de", "model_name": "de_core_news_md"},
    "it": {"lang_code": "it", "model_name": "it_core_news_sm"},
    "nl": {"lang_code": "nl", "model_name": "nl_core_news_sm"},
    "es": {"lang_code": "es", "model_name": "es_core_news_sm"},
    "ja": {"lang_code": "ja", "model_name": "ja_core_news_md"}
}

def ensure_language_support(language: str):
    """Check that *language* can be used for PII detection in this interpreter.

    Raises:
        Exception: if *language* has no spaCy model configured in ``MODELS``, or if it is
            Japanese ("ja") and the running Python is older than 3.9.
    """
    if language not in MODELS:
        raise Exception("Unsupported language: %s" % language)

    python_too_old_for_japanese = language == "ja" and sys.version_info[:2] < (3, 9)
    if python_too_old_for_japanese:
        raise Exception("Japanese support requires Python 3.9 or later. Upgrade your internal code-env python version or remove the 'ja' from the supported languages list in your Guardrail config.")


@dataclass
class AnalyzerRequest:
    """A single text submitted to the analyzer batcher, tagged with its language."""

    # Text to scan for PII entities
    text: str
    # Language code passed to the analyzer; also the key the batcher groups requests by
    language: str


class PresidioServer:
    """PII guardrail backed by Presidio.

    Commands are dispatched by ``handler``: a ``start`` command carries the guardrail settings,
    and ``process-*`` commands carry queries/responses whose text fields are analyzed and then,
    depending on ``settings["detectionAction"]``, anonymized in place, flagged, or rejected.
    """

    def __init__(self):
        self.anonymizer_operators: Dict[str, OperatorConfig] = None
        self.anonymizer: AnonymizerEngine = None
        # Only set when no fixed "language" is configured: maps a text to a language code
        self.language_detector: Callable[[str], str] = None
        # language -> entity types to detect; a language missing from the dict means "all entities"
        self.entities: Dict[str, List[str]] = None
        self.settings: Dict = None
        self.analyzer: AnalyzerEngine = None
        self.batcher: Batcher[AnalyzerRequest, List[RecognizerResult]] = None
        self.started: bool = False
        # One worker: the start command and every analyzer batch run one at a time on this thread
        self.executor = ThreadPoolExecutor(1)

    def handle_start_command(self, command):
        """Configure the analyzer, anonymizer and request batcher from ``command["settings"]``.

        ``settings["supportedLanguages"]`` is normalized in place (stripped, blanks and
        duplicates dropped, defaulting to ``[DEFAULT_LANGUAGE]``) so that the NLP models, the
        analyzer and every later membership check all use the same language codes.

        Raises:
            Exception: if a supported language is not usable (see ``ensure_language_support``).
        """
        self.settings = command["settings"]
        self.anonymizer_operators = self._build_anonymizer_operators()

        if not self.settings.get("language"):
            # No fixed language configured: detect it per text. Seeding langdetect makes its
            # results reproducible from one call to the next.
            import langdetect

            langdetect.DetectorFactory.seed = 0
            self.language_detector = langdetect.detect

        supported_languages = self._normalize_languages(self.settings.get("supportedLanguages"))
        self.settings["supportedLanguages"] = supported_languages

        for language in supported_languages:
            ensure_language_support(language)
        models = [MODELS[language] for language in supported_languages]

        provider = NlpEngineProvider(
            nlp_configuration={"nlp_engine_name": "spacy", "models": models}
        )
        nlp_engine = provider.create_engine()

        self.analyzer = AnalyzerEngine(
            default_score_threshold=self.settings["confidenceThreshold"],
            supported_languages=supported_languages,
            nlp_engine=nlp_engine,
        )
        self.anonymizer = AnonymizerEngine()
        self.entities = self._build_entities_filter(supported_languages)
        self.batcher = self._create_batcher()

        logger.info("Presidio started with settings: %s", self.settings)
        self.started = True

    def _build_anonymizer_operators(self) -> Dict[str, OperatorConfig]:
        """Return the anonymizer operators for ``settings["detectionAction"]``.

        Returns an empty dict for actions that are not HASH, MASK or REDACT.
        """
        action = self.settings["detectionAction"]
        if action == "HASH":
            return {"DEFAULT": OperatorConfig("hash")}
        if action == "MASK":
            return {
                "DEFAULT": OperatorConfig(
                    "mask",
                    {
                        "masking_char": "*",
                        "from_end": True,
                        "chars_to_mask": self.settings["charsToMask"],
                    },
                )
            }
        if action == "REDACT":
            return {"DEFAULT": OperatorConfig("redact")}
        return {}

    @staticmethod
    def _normalize_languages(languages):
        """Return *languages* stripped, without blanks or duplicates (order kept).

        Falls back to ``[DEFAULT_LANGUAGE]`` when nothing usable remains (including ``None``).
        """
        stripped = [language.strip() for language in (languages or [])]
        unique = list(dict.fromkeys(language for language in stripped if language))
        return unique or [DEFAULT_LANGUAGE]

    def _supported_entities(self, language):
        """Return the set of entity types the analyzer's recognizers support for *language*."""
        entities = set()
        for recognizer in self.analyzer.registry.recognizers:
            if recognizer.supported_language == language:
                entities.update(recognizer.get_supported_entities())
        return entities

    def _build_entities_filter(self, languages):
        """Map each language to the entity types to detect, according to ``settings["entitiesMode"]``.

        An empty dict (``ALL`` mode, or any unrecognized mode) means every entity is detected.

        NOTE(review): Presidio appears to treat an empty entity list like ``None`` (all
        entities), so an empty include list, or an exclude list covering every entity, would
        still detect everything. Confirm this is intended.
        """
        mode = self.settings["entitiesMode"]
        if mode == "EXPLICIT_INCLUDE":
            return {language: self.settings["includedEntities"] for language in languages}
        if mode == "EXPLICIT_EXCLUDE":
            excluded_entities = set(self.settings["excludedEntities"])
            return {
                language: list(self._supported_entities(language) - excluded_entities)
                for language in languages
            }
        return {}

    def _create_batcher(self):
        """Build the batcher that groups analyzer requests by language and runs them on ``self.executor``."""
        batch_analyzer = BatchAnalyzerEngine(self.analyzer)

        def _run_batch_sync(requests: List[AnalyzerRequest]) -> List[List[RecognizerResult]]:
            logger.info("Processing a batch of %s PII requests", len(requests))
            # group_by below is expected to keep one language per batch
            language = requests[0].language
            return batch_analyzer.analyze_iterator(
                texts=[request.text for request in requests],
                language=language,
                entities=self.entities.get(language),
            )

        async def _process_batch_async(requests: List[AnalyzerRequest]) -> List[List[RecognizerResult]]:
            # The analyzer call is synchronous; run it off the event loop
            return await asyncio.get_running_loop().run_in_executor(self.executor, _run_batch_sync, requests)

        return Batcher[AnalyzerRequest, List[RecognizerResult]](
            batch_size=100,
            timeout=0,
            process_batch=_process_batch_async,
            group_by=lambda request: request.language,
        )

    def _ensure_started(self):
        """Raise unless the start command has completed (kept as a real check, not an ``assert``)."""
        if not self.started:
            raise Exception("Presidio not started")

    @staticmethod
    def _is_blank(text):
        """Return True when *text* is None or contains only whitespace."""
        return text is None or not text.strip()

    def _infer_language(self, text):
        """Return the configured language, or the detected language of *text*.

        If detection raises, falls back to ``DEFAULT_LANGUAGE`` when it is supported, otherwise
        to the first supported language.
        """
        fixed_language = self.settings.get("language")
        if fixed_language:
            return fixed_language

        try:
            return self.language_detector(text)
        except Exception as e:
            # we shouldn't fail the whole PII detection just because the language couldn't be inferred
            logger.info(
                "Falling back to default language for PII detection, language detection failed with error: %s",
                e,
            )
            if DEFAULT_LANGUAGE in self.settings["supportedLanguages"]:
                return DEFAULT_LANGUAGE
            return self.settings["supportedLanguages"][0]

    def _redact_or_fail(self, text, results):
        """Apply ``settings["detectionAction"]`` to *text*, given the non-empty analyzer *results*.

        Returns:
            The processed text: *text* unchanged for FLAG_ONLY, otherwise the anonymized text.

        Raises:
            Exception: for FAIL, listing each detected span and its entity type.
        """
        if self.settings["detectionAction"] == "FAIL":
            error_detail = ", ".join(
                "'{}' ({})".format(text[res.start : res.end], res.entity_type)
                for res in results
            )
            raise Exception(
                "PII was detected and denied: {}".format(error_detail)
            )
        if self.settings["detectionAction"] == "FLAG_ONLY":
            return text

        return self.anonymizer.anonymize(
            text, results, self.anonymizer_operators
        ).text

    async def _process_text(self, text):
        """Analyze *text* and return a tuple ``(processed_text, [{"type": entity_type}, ...])``.

        Raises:
            Exception: if the text's language is unsupported and ``unsupportedLanguageAction``
                is FAIL, or if PII is found and ``detectionAction`` is FAIL.
        """
        # Raw text may contain PII: only log it at DEBUG
        logger.debug("Analyzing: %s", text)

        # Called directly on the event-loop thread, never in the executor:
        # langdetect is not thread-safe (https://github.com/Mimino666/langdetect/pull/33)
        language = self._infer_language(text)
        logger.info("Language: %s", language)

        if language not in self.settings["supportedLanguages"]:
            if self.settings["unsupportedLanguageAction"] == "FAIL":
                raise Exception(
                    "Language of the text is %s but PII detection is not supported for this language"
                    % language
                )
            logger.warning("Skipping PII detection: language %s not supported", language)
            return (text, [])

        analyzer_results = await self.batcher.process(AnalyzerRequest(text, language))
        if not analyzer_results:
            return (text, [])

        logger.info("PII was detected: %s", analyzer_results)
        redacted_text = self._redact_or_fail(text, analyzer_results)
        recognized = [{"type": str(result.entity_type)} for result in analyzer_results]
        return (redacted_text, recognized)

    async def _process_message_part_text(self, part):
        """Process ``part["text"]`` in place; return the list of detected entities."""
        (processed_text, entities) = await self._process_text(part["text"])
        part["text"] = processed_text
        return entities

    async def _process_message_content(self, message):
        """Process ``message["content"]`` in place; return the list of detected entities."""
        (processed_text, entities) = await self._process_text(message["content"])
        message["content"] = processed_text
        return entities

    async def handle_completion_query(self, request: ProcessSinglePromptCommand):
        """Process every non-blank message part text and message content of ``request["query"]`` concurrently.

        Returns ``{"redactedQuery": query, "recognizedEntities": [...]}``. The query is modified in place.
        """
        self._ensure_started()
        query = request["query"]

        coroutines = []
        for message in query["messages"]:
            for part in message.get("parts") or []:
                if self._is_blank(part.get("text")):
                    logger.info("Skipping empty message part.")
                    continue
                coroutines.append(self._process_message_part_text(part))

            if self._is_blank(message.get("content")):
                logger.info("Skipping empty message content.")
                continue
            coroutines.append(self._process_message_content(message))

        tasks = [asyncio.create_task(coro) for coro in coroutines]
        entities_list_list = await asyncio.gather(*tasks)
        logger.debug("Processing done, now messages: %s", query)
        logger.info("And recognized entities: %s", entities_list_list)
        return {
            "redactedQuery": query,
            "recognizedEntities": [e for el in entities_list_list for e in el]
        }

    async def handle_completion_response(self, request: ProcessCompletionResponseCommand):
        """Process ``request["completionResponse"]["text"]`` in place.

        Returns ``{"redactedResponse": response, "recognizedEntities": [...]}``, including when
        the text is missing or blank (no entities in that case).
        """
        self._ensure_started()
        response = request["completionResponse"]

        if self._is_blank(response.get("text")):
            logger.info("Skipping empty text.")
            return {"redactedResponse": response, "recognizedEntities": []}

        response["text"], entities_list = await self._process_text(response["text"])
        logger.debug("Processing done, now completionResponse: %s", response)

        return {
            "redactedResponse": response,
            "recognizedEntities": entities_list
        }

    async def handle_embedding_query(self, request: ProcessSingleEmbeddingCommand):
        """Process ``request["query"]["text"]`` in place.

        Returns ``{"redactedQuery": query, "recognizedEntities": [...]}``, including when the
        text is missing or blank (no entities in that case).
        """
        self._ensure_started()
        query = request["query"]

        if self._is_blank(query.get("text")):
            logger.info("Skipping empty text.")
            return {"redactedQuery": query, "recognizedEntities": []}

        query["text"], entities_list = await self._process_text(query["text"])
        logger.debug("Processing done, now query: %s", query)

        return {
            "redactedQuery": query,
            "recognizedEntities": entities_list
        }

    async def _process_image_gen_prompt(self, image_gen_prompt: ImageGenerationPrompt):
        """Process ``image_gen_prompt["prompt"]`` in place; return the list of detected entities."""
        (processed_text, entities_list) = await self._process_text(image_gen_prompt["prompt"])
        image_gen_prompt["prompt"] = processed_text
        return entities_list

    async def handle_image_generation_query(self, request: ProcessSingleImageGenerationCommand):
        """Process every non-blank prompt and negative prompt of ``request["query"]`` concurrently.

        Returns ``{"redactedQuery": query, "recognizedEntities": [...]}``. The query is modified in place.
        """
        self._ensure_started()
        query = request["query"]

        prompt_iterator = itertools.chain(
            query.get("prompts") or [],
            query.get("negativePrompts") or [],
        )

        coroutines = []
        for image_gen_prompt in prompt_iterator:
            if self._is_blank(image_gen_prompt.get("prompt")):
                logger.info("Skipping empty prompt.")
                continue
            coroutines.append(self._process_image_gen_prompt(image_gen_prompt))

        tasks = [asyncio.create_task(coro) for coro in coroutines]
        entities_list_list = await asyncio.gather(*tasks)
        logger.debug("Processing done, now image generation query: %s", query)

        return {
            "redactedQuery": query,
            "recognizedEntities": [e for el in entities_list_list for e in el]
        }

    async def handler(self, command):
        """Dispatch *command* by its ``"type"`` and yield the handler's result.

        The start command runs on ``self.executor``; query commands run on the event loop.

        Raises:
            Exception: for an unknown command type.
        """
        command_type = command["type"]
        if command_type == "start":
            logger.info("Received start command: %s", command)
            yield await asyncio.get_running_loop().run_in_executor(self.executor, self.handle_start_command, command)
            return

        query_handlers = {
            "process-completion-query": self.handle_completion_query,
            "process-completion-response": self.handle_completion_response,
            "process-embedding-query": self.handle_embedding_query,
            "process-image-generation-query": self.handle_image_generation_query,
        }
        handle = query_handlers.get(command_type)
        if handle is None:
            raise Exception("Unknown command type: %s" % command_type)

        logger.debug("Received query: %s", command)
        yield await handle(command)


def log_exception(loop, context):
    """Asyncio exception handler: log the error described by *context* with its stack trace.

    If *context* carries no ``"exception"``, its ``"message"`` is wrapped in an ``Exception``
    so that the same log format is used.
    """
    error = context.get("exception")
    error = Exception(context.get("message")) if error is None else error

    stack_trace = "".join(traceback.format_exception(type(error), error, error.__traceback__))
    logger.error(
        f"Caught exception: {error}\n"
        f"Context: {context}\n"
        f"Stack trace: {stack_trace}"
    )


if __name__ == "__main__":
    LOGLEVEL = os.environ.get('LOGLEVEL', 'INFO').upper()  # Set LOGLEVEL=DEBUG to debug
    logging.basicConfig(level=LOGLEVEL,
                        format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s')

    # presidio logging induces significant perf overhead
    logging.getLogger("presidio-analyzer").setLevel(logging.WARNING)
    logging.getLogger("presidio-anonymizer").setLevel(logging.WARNING)
    watch_stdin()

    async def start_server():
        """Connect the Java link and serve PresidioServer commands over it."""
        # Inside a coroutine the running loop is the one to configure; get_running_loop() is
        # the idiomatic accessor and matches the rest of this file.
        asyncio.get_running_loop().set_exception_handler(log_exception)

        port, secret, server_cert = parse_javalink_args()
        link = AsyncJavaLink(port, secret, server_cert=server_cert)
        server = PresidioServer()
        await link.connect()
        await link.serve(server.handler)

    asyncio.run(start_server())
