import asyncio
import base64
import io
import logging
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Dict

from docxtpl import DocxTemplate

import dataiku
from dataiku.base.async_link import AsyncJavaLink
from dataiku.base.socket_block_link import parse_javalink_args
from dataiku.base.utils import watch_stdin
from dataiku.llm.python.blocks_graph.microcel import MicroCelEngine
from dataiku.llm.python.blocks_graph.utils import interpolate_cel
from dataikuapi.dss.document_extractor import ManagedFolderDocumentRef, InlineDocumentRef


logger = logging.getLogger("document_templating_server")


class DocumentTemplatingServer:
    """Server rendering CEL, Jinja and DOCX (docxtpl) templates for the Java backend.

    Commands are dispatched through :meth:`handler`. Blocking work runs on a
    thread pool, but :meth:`render` holds ``self.lock`` for its whole body, so
    renders are executed one at a time.
    """

    def __init__(self):
        self.started = False
        self.executor = ThreadPoolExecutor(16)
        # Held for the whole of render(): renders never run concurrently even
        # though the executor has several worker threads.
        self.lock = threading.Lock()
        # Run number used only to label log lines; incremented by handler().
        self.run_counter = 1

    def start(self, command):
        """Initialize the server (no special setup needed for docxtpl).

        ``command`` is accepted for symmetry with :meth:`render` and is not used.
        """
        self.started = True
        logger.info("Document templating server started")

    def render(self, command: Dict) -> Dict:
        """Render a document template with the provided variables.

        Reads ``command["config"]``, which must contain ``templateType``
        (``CEL_EXPANSION``, ``JINJA`` or ``DOCX_JINJA``), ``templateRef``,
        ``destinationRef`` and ``data``.

        Returns:
            The ``{"documentRef": ...}`` dict built by :meth:`_write_output`.

        Raises:
            KeyError: if ``config``, ``templateType`` or ``data`` is missing.
            ValueError: for an unsupported template type or an invalid
                template/destination ref.
        """
        logger.info(f"Rendering template (run {self.run_counter})")

        with self.lock:
            config = command["config"]
            template_type = config["templateType"]
            # Note: the template is read before template_type is validated.
            template = self._read_template(config.get("templateRef"))
            destination_ref = config.get("destinationRef")

            if template_type in ("CEL_EXPANSION", "JINJA"):
                # Managed-folder templates come back as raw bytes; the text
                # engines need a str.
                if isinstance(template, bytes):
                    template = template.decode("utf-8")
                if template_type == "CEL_EXPANSION":
                    output_str = self._render_cel(template, config["data"])
                else:
                    output_str = self._render_jinja(template, config["data"])
                return self._write_output(
                    destination_ref,
                    output_str=output_str,
                    mime_type="text/plain")

            if template_type == "DOCX_JINJA":
                # _render_docx_template consumes base64 text (presumably what
                # inline refs carry — TODO confirm); raw bytes read from a
                # managed folder are encoded to match.
                if isinstance(template, bytes):
                    template = base64.b64encode(template).decode('utf-8')
                output_base64 = self._render_docx_template(template, config["data"])
                return self._write_output(
                    destination_ref,
                    output_base64=output_base64,
                    mime_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document")

            raise ValueError(f"Unsupported template type: {template_type}")

    def _read_template(self, template_ref: dict):
        """Load a template from an inline or managed-folder document ref.

        Returns:
            For ``inline_document`` refs, the ref's ``content`` value as-is;
            for ``managed_folder`` refs, whatever the folder's download stream
            returns from ``read()`` (expected to be bytes).

        Raises:
            ValueError: if the ref is missing, incomplete, or of an
                unsupported type.
        """
        if not template_ref:
            raise ValueError("Missing input template ref")

        template_ref_type = template_ref.get("type")
        if template_ref_type == "inline_document":
            return template_ref.get("content")
        if template_ref_type == "managed_folder":
            managed_folder_id = template_ref.get("managedFolderId")
            if not managed_folder_id:
                raise ValueError("Missing managedFolderId in input document ref")
            input_template_path = template_ref.get("filePath")
            if not input_template_path:
                raise ValueError("Missing filePath in input document ref")
            folder = dataiku.Folder(managed_folder_id)
            with folder.get_download_stream(input_template_path) as stream:
                return stream.read()
        # Must be raised (not returned) so callers never treat the error as a template.
        raise ValueError(f"Unsupported document ref type {template_ref_type}")

    def _render_cel(self, template_string: str, data: dict) -> str:
        """Interpolate CEL expressions in ``template_string`` against ``data``.

        Raises:
            ValueError: if the template is empty or missing.
        """
        if not template_string:
            raise ValueError("A template must be specified for text templating")
        engine = MicroCelEngine(data)
        return interpolate_cel(engine, template_string)

    def _render_jinja(self, template_string: str, data: dict) -> str:
        """Render ``template_string`` as a Jinja template with ``data`` as variables.

        Raises:
            ValueError: if the template is empty or missing.
        """
        if not template_string:
            raise ValueError("A template must be specified for text templating")
        # NOTE(review): plain jinja2.Template is not sandboxed; this assumes
        # templates come from trusted users — confirm.
        from jinja2 import Template
        return Template(template_string).render(**data)

    def _render_docx_template(self, template_base64: str, data: dict) -> str:
        """Render a base64-encoded DOCX template with docxtpl.

        Returns:
            The rendered document, base64-encoded as a str.

        Raises:
            ValueError: if the template is empty or missing.
        """
        if not template_base64:
            raise ValueError("Missing template for DOCX templating")
        template_bytes = base64.b64decode(template_base64)

        doc = DocxTemplate(io.BytesIO(template_bytes))
        doc.render(data)

        output_stream = io.BytesIO()
        doc.save(output_stream)
        return base64.b64encode(output_stream.getvalue()).decode('utf-8')

    def _write_output(self, destination_ref: dict, output_str=None, output_base64=None, mime_type=None) -> dict:
        """Deliver rendered output to the destination described by ``destination_ref``.

        One of ``output_str`` (text) or ``output_base64`` (base64-encoded
        binary) is expected; if both are given, ``output_str`` is used. An
        empty string is a valid rendering result and is written as-is.

        Returns:
            ``{"documentRef": ...}`` holding the JSON form of an
            ``InlineDocumentRef`` or a ``ManagedFolderDocumentRef``.

        Raises:
            ValueError: if the ref is missing, incomplete, of an unsupported
                type, or no output content was provided.
        """
        if not destination_ref:
            raise ValueError("Missing destination ref")

        destination_ref_type = destination_ref.get("type")
        if destination_ref_type == "inline_document":
            # `is not None` rather than truthiness: an empty render is valid output.
            if output_str is not None:
                content = output_str
                content_type = InlineDocumentRef.CONTENT_TYPE_PLAIN_TEXT
            elif output_base64 is not None:
                content = output_base64
                content_type = InlineDocumentRef.CONTENT_TYPE_BASE64_BYTES
            else:
                raise ValueError("No output content found for document templating inline output")
            return {
                "documentRef": InlineDocumentRef(
                    content,
                    content_type,
                    mime_type=mime_type).as_json()
            }

        if destination_ref_type == "managed_folder":
            managed_folder_id = destination_ref.get("managedFolderId")
            if not managed_folder_id:
                raise ValueError("Missing managedFolderId in destination document ref")
            destination_filepath = destination_ref.get("filePath")
            if not destination_filepath:
                raise ValueError("Missing filePath in destination document ref")

            if output_str is not None:
                file_content = output_str.encode("utf-8")
            elif output_base64 is not None:
                file_content = base64.b64decode(output_base64)
            else:
                raise ValueError("No output content found for document templating file output")

            folder = dataiku.Folder(managed_folder_id)
            folder.upload_data(destination_filepath, file_content)
            return {
                "documentRef": ManagedFolderDocumentRef(
                    destination_filepath,
                    managed_folder_id,
                    mime_type=mime_type).as_json()
            }

        raise ValueError(f"Unsupported document ref type {destination_ref_type}")

    async def handler(self, command):
        """Dispatch one command received over the Java link.

        Async generator: for ``start`` and ``render`` commands it yields the
        result of the corresponding method (``None`` for ``start``), executed
        on ``self.executor`` so the event loop is not blocked.

        Raises:
            Exception: for an unknown command type.
        """
        if command["type"] == "start":
            logger.info("Received start command")
            yield await asyncio.get_running_loop().run_in_executor(
                self.executor, self.start, command
            )
        elif command["type"] == "render":
            logger.info(f"=== Start rendering - run {self.run_counter} ===")
            try:
                yield await asyncio.get_running_loop().run_in_executor(
                    self.executor, self.render, command
                )
                logger.info(f"=== End rendering - run {self.run_counter} ===")
            finally:
                # Advance the run number even when rendering failed.
                self.run_counter += 1
        else:
            raise Exception(f"Unknown command type: {command['type']}")


if __name__ == "__main__":
    # Log level comes from the LOGLEVEL environment variable (default INFO).
    LOGLEVEL = os.environ.get('LOGLEVEL', 'INFO').upper()
    log_format = '[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s'
    logging.basicConfig(level=LOGLEVEL, format=log_format)

    # NOTE(review): presumably ties this process's lifetime to its parent via
    # stdin — confirm against dataiku.base.utils.watch_stdin.
    watch_stdin()

    async def start_server():
        """Connect to the Java side and serve its commands with a DocumentTemplatingServer."""
        port, secret, server_cert = parse_javalink_args()
        java_link = AsyncJavaLink(port, secret, server_cert=server_cert)
        templating_server = DocumentTemplatingServer()

        await java_link.connect()
        await java_link.serve(templating_server.handler)

    # NOTE(review): asyncio debug mode is always on here; it adds overhead and
    # extra warnings — confirm this is intended outside development.
    asyncio.run(start_server(), debug=True)
