import hashlib
import io
import json
import os
import re
from typing import Any, Dict, List, Tuple

import dataiku
from backend.config import (
    get_guardrails_enabled,
    get_guardrails_pattern,
    get_uploads_managedfolder_id,
)
from backend.utils.logging_utils import get_logger
from backend.database.user_store_protocol import IUserStore
from backend.models.events import EventKind
from backend.services.helpers.guardrails_matcher import GuardrailsMatcher
try:
    from dataikuapi.dss.document_extractor import DocumentExtractor, ManagedFolderDocumentRef
except ImportError:  # pragma: no cover
    DocumentExtractor = None  # type: ignore
    ManagedFolderDocumentRef = None  # type: ignore

logger = get_logger(__name__)


def extract_document_text_content(
    folder,
    file_path: str,
    file_ext: str,
    doc_extractor=None,
    folder_id: str = None
) -> str:
    """
    Return the text of a document stored in a managed folder.

    Plain-text extensions (txt, md, html) are read straight from ``folder``
    and decoded as UTF-8 with replacement of undecodable bytes. Rich documents
    and images (pdf, docx, pptx, png, jpg, jpeg) go through
    ``doc_extractor.text_extract`` with OCR image handling, which requires
    both ``doc_extractor`` and ``folder_id``. Any other extension yields "".

    Failures are logged as warnings and never raised; whatever text was
    obtained before the failure (usually "") is returned.
    """
    plain_text_exts = {"txt", "md", "html"}
    extractor_exts = {"pdf", "docx", "pptx", "png", "jpg", "jpeg"}
    extracted = ""

    # Plain text: a direct read of the stored bytes is enough.
    if file_ext in plain_text_exts:
        try:
            with folder.get_file(file_path) as http_response:
                raw_bytes = http_response.raw.read()
                extracted = raw_bytes.decode("utf-8", errors="replace")
        except Exception as err:
            logger.warning(f"Failed to read text file {file_path}: {err}")
        return extracted

    # Extensions outside both families are not handled at all.
    if file_ext not in extractor_exts:
        return extracted

    # Rich documents cannot be processed without the extractor and folder id.
    if not (doc_extractor and folder_id):
        logger.warning("DocumentExtractor required but not provided for complex file.")
        return ""

    try:
        source_ref = ManagedFolderDocumentRef(file_path, folder_id)
        extraction = doc_extractor.text_extract(document=source_ref, image_handling_mode="OCR")
        extraction._fail_unless_success()

        # The plain-string property is the primary source of text.
        extracted = extraction.text_content or ""

        # When no text came back, fall back to the "description" entry of the
        # structured content if there is one.
        if not extracted and extraction.content and "description" in extraction.content:
            extracted = extraction.content["description"]
    except Exception as err:
        logger.warning(f"Failed to extract text from {file_path}: {err}")

    return extracted


def _read_sidecar_json(cache_folder, sidecar_filename: str) -> Any:
    """
    Load and parse a JSON sidecar from ``cache_folder``.

    Uses ``get_download_stream`` when the folder object exposes it and falls
    back to ``get_file`` otherwise. Callers in this module hand over both a
    ``dataiku.Folder`` and project managed-folder handles that are otherwise
    accessed through ``get_file``/``put_file``.
    NOTE(review): these handle types appear to expose different method names;
    confirm against the Dataiku API reference.

    Raises whatever the underlying read or JSON parsing raises.
    """
    if hasattr(cache_folder, "get_download_stream"):
        with cache_folder.get_download_stream(sidecar_filename) as stream:
            return json.load(stream)
    with cache_folder.get_file(sidecar_filename) as response:
        return json.loads(response.raw.read().decode("utf-8"))


def _write_sidecar_bytes(cache_folder, sidecar_filename: str, payload: bytes) -> None:
    """
    Store ``payload`` as ``sidecar_filename`` in ``cache_folder``.

    Uses ``put_file`` when the folder object exposes it and ``upload_stream``
    otherwise (same duck-typing rationale as the read helper).
    """
    if hasattr(cache_folder, "put_file"):
        cache_folder.put_file(sidecar_filename, io.BytesIO(payload))
    else:
        cache_folder.upload_stream(sidecar_filename, io.BytesIO(payload))


def verify_document_content(
    text_content: str,
    matcher: GuardrailsMatcher,
    cache_folder=None,
    sidecar_filename: str = None
) -> Tuple[bool, str]:
    """
    Check ``text_content`` against ``matcher``, optionally caching the verdict.

    When both ``cache_folder`` and ``sidecar_filename`` are given, a JSON
    sidecar holding the verdict is consulted first and rewritten after a
    fresh check. A cached verdict is only trusted when it was produced for
    the same pattern (``matcher.raw_config``) AND the same text (SHA-256 of
    the content); sidecars without a matching content hash are treated as a
    cache miss, so a verdict computed for different text is never reused.

    Cache read and write failures never raise: a failed read is a cache miss,
    a failed write is logged as a warning.

    Returns: (is_violation, failure_message). The message is "" when there is
    no violation and carries a "(cached)" suffix when served from the sidecar.
    """
    use_cache = bool(cache_folder and sidecar_filename)
    content_hash = None
    if use_cache:
        # errors="replace" keeps hashing from failing on unencodable characters.
        content_hash = hashlib.sha256(
            (text_content or "").encode("utf-8", errors="replace")
        ).hexdigest()

    # 1. Check Cache
    if use_cache:
        try:
            cached_data = _read_sidecar_json(cache_folder, sidecar_filename)
            if (
                cached_data.get("checked_pattern") == matcher.raw_config
                and cached_data.get("content_hash") == content_hash
            ):
                result = cached_data.get("result")
                if result == "PASS":
                    return False, ""
                if result == "FAIL":
                    return True, "Restricted content detected (cached)"
        except Exception:
            # Missing, unreadable or malformed sidecar: deliberately treated
            # as a cache miss so the real check below always runs.
            pass

    # 2. Perform Check
    violation = bool(matcher.check_structure(text_content))

    # 3. Write Cache
    if use_cache:
        try:
            result_data = {
                "result": "FAIL" if violation else "PASS",
                "checked_pattern": matcher.raw_config,
                "content_hash": content_hash,
            }
            _write_sidecar_bytes(
                cache_folder, sidecar_filename, json.dumps(result_data).encode("utf-8")
            )
        except Exception as e:
            logger.warning(f"Failed to write guardrails sidecar: {e}")

    return violation, "Restricted content detected" if violation else ""


# --- Main Service Functions ---

def check_guardrails(documents: list[dict]) -> List[Dict[str, Any]]:
    """
    Check conversation documents against guardrails rules.

    Each document dict is expected to carry its extracted content under
    "text"; "name", "document_path" and "metadata" (with an optional
    "text_path") are read when present. The uploads managed folder is only
    used to cache verdicts in a ``<text_path>_guardrails.json`` sidecar.

    Returns a list of ``{"name", "document_path"}`` dicts, one per document
    flagged by the matcher. Returns an empty list when no guardrails pattern
    is configured.

    Errors are logged and never raised, and they are isolated: an unavailable
    cache folder only disables caching, and a failure on one document does
    not prevent the remaining documents from being checked.
    """
    failed_documents: List[Dict[str, Any]] = []
    raw_config = get_guardrails_pattern()

    if not raw_config:
        return failed_documents

    matcher = GuardrailsMatcher(raw_config)

    # The folder is only a verdict cache; losing it must not disable the
    # checks themselves, so fall back to uncached checking.
    try:
        folder = dataiku.Folder(get_uploads_managedfolder_id())
    except Exception as e:
        logger.exception(f"Guardrails cache folder unavailable, checking without cache: {e}")
        folder = None

    for doc in documents:
        # Guard each document separately so one malformed entry cannot
        # silently skip the checks of every document after it.
        try:
            doc_name = doc.get("name") or "Document"
            document_path = doc.get("document_path", "")
            doc_text = doc.get("text", "")

            # Determine sidecar path based on metadata ("none" marks documents
            # whose extracted text was not persisted).
            metadata = doc.get("metadata") or {}
            text_path = metadata.get("text_path")
            sidecar_path = f"{text_path}_guardrails.json" if (text_path and text_path != "none") else None

            is_violation, _ = verify_document_content(
                doc_text, matcher, folder, sidecar_path
            )

            if is_violation:
                failed_documents.append({"name": doc_name, "document_path": document_path})
        except Exception as e:
            # NOTE(review): a document whose check errors is still reported as
            # not violating, as before; only the blast radius is reduced.
            logger.exception(f"Error during guardrails check: {e}")

    return failed_documents


def _read_folder_text(project_folder, path: str) -> str:
    """Return the UTF-8 decoded content of ``path`` in ``project_folder``, or "" if it cannot be read."""
    try:
        with project_folder.get_file(path) as response:
            return response.raw.read().decode("utf-8", errors="replace")
    except Exception:
        # Best-effort lookup: an absent or unreadable file only means there
        # is nothing to reuse, so the caller moves on to the next source.
        return ""


def _resolve_attachment_text(
    project_folder,
    folder_id: str,
    doc_extractor,
    document_path: str,
    file_ext: str,
    existing_text_path,
) -> Tuple[str, str]:
    """
    Return ``(text_content, text_path)`` for one conversation attachment.

    Sources are tried in order: the text file referenced by the DB record,
    a previously written ``outputs/.../guardrails_extracted.json`` file, then
    a fresh extraction. ``text_path`` is "none" when the text is not backed
    by a persisted extraction file.
    """
    # 1. Try reusing the text file recorded in the DB
    if existing_text_path:
        reused_text = _read_folder_text(project_folder, existing_text_path)
        if reused_text:
            return reused_text, existing_text_path

    # For conversation attachments, extractions are stored under "outputs/".
    if document_path.startswith("inputs/"):
        output_base = document_path.replace("inputs/", "outputs/", 1)
    else:
        output_base = f"outputs/{document_path}"
    potential_text_path = f"{output_base}/guardrails_extracted.json"

    # 2. The extraction file may exist in the folder without being in the DB
    stored_text = _read_folder_text(project_folder, potential_text_path)
    if stored_text:
        return stored_text, potential_text_path

    # 3. Perform fresh extraction
    fresh_text = extract_document_text_content(
        project_folder, document_path, file_ext, doc_extractor, folder_id
    )
    if not fresh_text or file_ext in {"txt", "md", "html"}:
        # Nothing extracted, or a plain text file that is not persisted.
        return fresh_text, "none"

    # Persist extraction for complex files so later checks can reuse it.
    try:
        project_folder.put_file(potential_text_path, io.BytesIO(fresh_text.encode("utf-8")))
    except Exception as err:
        logger.warning(f"Failed to persist extracted text for {document_path}: {err}")
        return fresh_text, "none"
    return fresh_text, potential_text_path


def _record_guardrails_status(store, conv_id: str, doc_record, doc_path: str, status: str, error_message=None) -> None:
    """Write ``status`` (and the error message, if any) into the record's metadata and upsert it."""
    doc_meta = doc_record.document_metadata or {}
    doc_meta["guardrails_status"] = status
    if error_message:
        doc_meta["guardrails_error"] = error_message
    else:
        # A passing document must not keep a stale error from an earlier run.
        doc_meta.pop("guardrails_error", None)
    store.upsert_derived_document(conv_id, doc_record.document_name, doc_path, doc_meta)


def process_documents_for_guardrails(
    store,
    conv_id: str,
    attachments: List[Dict[str, Any]],
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Process conversation documents: Extract -> Check -> Update DB.

    If a document already has extracted text referenced by its DB record,
    that text is reused for the guardrails check. Otherwise the text is
    extracted first and then checked against the guardrails rules.
    Attachments without a "document_path" are ignored.

    Returns a dict with three lists:
      - "content_violations": documents flagged by the guardrails check
      - "extraction_failures": documents for which no text could be obtained
      - "all_checked": id/name/path/status of every processed document that
        has a derived-document record for this conversation

    All three lists are empty when guardrails are disabled, no pattern is
    configured, there are no attachments, or DocumentExtractor could not be
    imported.
    """
    if not get_guardrails_enabled() or not get_guardrails_pattern() or not attachments:
        return {"content_violations": [], "extraction_failures": [], "all_checked": []}

    if DocumentExtractor is None:
        # Make the skipped check visible instead of passing silently.
        logger.warning("DocumentExtractor is unavailable; skipping guardrails checks for conversation documents.")
        return {"content_violations": [], "extraction_failures": [], "all_checked": []}

    folder_id = get_uploads_managedfolder_id()
    client = dataiku.api_client()
    project = client.get_default_project()
    project_folder = project.get_managed_folder(folder_id)
    doc_extractor = DocumentExtractor(client, project.project_key)

    # Load existing DB records to reuse extraction
    from backend.schemas import schemas
    derived_docs: List[schemas.DerivedDocument] = store.get_derived_documents(conv_id) or []
    existing_docs = {doc.document_path: doc for doc in derived_docs}

    documents_for_check: List[Dict[str, Any]] = []
    extraction_failed_documents: List[Dict[str, Any]] = []

    for attachment in attachments:
        document_path = attachment.get("document_path")
        if not document_path:
            continue
        document_name = attachment.get("document_name") or attachment.get("name") or "Document"
        file_ext = os.path.splitext(document_name.lower())[1].lstrip(".")

        # A record may exist with no metadata at all; treat that as empty
        # instead of failing on None.
        doc_record = existing_docs.get(document_path)
        doc_meta = (doc_record.document_metadata or {}) if doc_record else {}

        text_content, text_path = _resolve_attachment_text(
            project_folder,
            folder_id,
            doc_extractor,
            document_path,
            file_ext,
            doc_meta.get("text_path"),
        )

        if text_content:
            documents_for_check.append({
                "name": document_name,
                "document_path": document_path,
                "text": text_content,
                "metadata": {"text_path": text_path},
            })
        else:
            # Text extraction failed - mark document as blocked
            extraction_failed_documents.append(
                {
                    "name": document_name,
                    "document_path": document_path,
                }
            )
            logger.warning(f"Text extraction failed for document {document_name} at {document_path}")

    # Run Checks
    content_violation_documents = check_guardrails(documents_for_check)

    content_violation_paths = {doc.get("document_path") for doc in content_violation_documents}
    all_checked_documents = []

    # Update DB for documents that were checked
    for checked_doc in documents_for_check:
        doc_path = checked_doc.get("document_path")
        if doc_path not in existing_docs:
            continue
        doc_record = existing_docs[doc_path]
        if doc_path in content_violation_paths:
            status = "content_violation"
            _record_guardrails_status(
                store, conv_id, doc_record, doc_path, status, "Document contains restricted content"
            )
        else:
            status = "passed"
            _record_guardrails_status(store, conv_id, doc_record, doc_path, status)
        all_checked_documents.append(
            {
                "id": doc_record.id,
                "name": checked_doc.get("name"),
                "document_path": doc_path,
                "guardrails_status": status,
            }
        )

    # Update DB for extraction failures
    for failed_doc in extraction_failed_documents:
        doc_path = failed_doc.get("document_path")
        if doc_path not in existing_docs:
            continue
        doc_record = existing_docs[doc_path]
        _record_guardrails_status(
            store, conv_id, doc_record, doc_path, "extraction_failed", "Unable to extract text from document"
        )
        all_checked_documents.append(
            {
                "id": doc_record.id,
                "name": failed_doc.get("name"),
                "document_path": doc_path,
                "guardrails_status": "extraction_failed",
            }
        )

    return {
        "content_violations": content_violation_documents,
        "extraction_failures": extraction_failed_documents,
        "all_checked": all_checked_documents,
    }


def check_agent_documents_guardrails(
    agent_id: str,
    documents: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """
    Run the guardrails check on an agent's documents ahead of indexing.

    Delegates to the shared agent-document handler in non-runtime mode
    without selecting the published zone, and returns its result dict
    unchanged.
    """
    indexing_outcome = _process_agent_docs(
        agent_id,
        documents,
        use_published=False,
        runtime_mode=False,
    )
    return indexing_outcome


def check_agent_guardrails_at_runtime(
    agent_id: str,
    documents: List[Dict[str, Any]],
    use_published: bool = False,
) -> Dict[str, Any]:
    """
    Run the guardrails check on an agent's documents while chatting.

    Delegates to the shared agent-document handler in runtime mode;
    ``use_published`` is forwarded as-is and the handler's result dict is
    returned unchanged.
    """
    runtime_outcome = _process_agent_docs(
        agent_id,
        documents,
        use_published=use_published,
        runtime_mode=True,
    )
    return runtime_outcome


def _process_agent_docs(
    agent_id: str,
    documents: List[Dict[str, Any]],
    use_published: bool,
    runtime_mode: bool = False
) -> Dict[str, Any]:
    """
    Shared implementation of agent document checking for indexing and runtime.

    Returns a dict with "content_violations", "extraction_failures" and
    "passed" lists of document dicts, plus a "has_violations" flag that is
    True when either of the first two lists is non-empty. When guardrails are
    disabled, no pattern is configured, or the folders cannot be set up, all
    input documents are returned as passed.
    """
    from backend.constants import DRAFT_ZONE, PUBLISHED_ZONE
    from backend.services.agent_assets import get_ua_project_folder
    from backend.utils.project_utils import get_managed_folder_in_zone, get_ua_project

    outcome = {
        "content_violations": [],
        "extraction_failures": [],
        "passed": [],
        "has_violations": False,
    }

    if not get_guardrails_enabled():
        outcome["passed"] = documents
        return outcome

    pattern = get_guardrails_pattern()
    if not pattern:
        outcome["passed"] = documents
        return outcome

    # Documents pending deletion are never checked; at runtime only active
    # documents are considered.
    candidates = []
    for entry in documents:
        if entry.get("deletePending"):
            continue
        if runtime_mode and not entry.get("active"):
            continue
        candidates.append(entry)

    if not candidates:
        return outcome

    # Resolve the source-documents folder and the extraction cache folder.
    try:
        if not runtime_mode:
            project, documents_folder = get_ua_project_folder(agent_id)
            extracted_folder = _get_or_create_extracted_folder(project, DRAFT_ZONE)
        else:
            zone = PUBLISHED_ZONE if use_published else DRAFT_ZONE
            project = get_ua_project(agent_id)
            documents_folder = get_managed_folder_in_zone(project, f"documents_{zone}", zone)
            extracted_folder = get_managed_folder_in_zone(project, f"extracted_documents_{zone}", zone)
            if not extracted_folder:
                # Attempt creation if missing
                extracted_folder = _get_or_create_extracted_folder(project, zone)

        if not documents_folder:
            outcome["passed"] = documents
            return outcome

        documents_folder_id = documents_folder.id
    except Exception as e:
        logger.exception(f"Failed to setup folders for agent guardrails: {e}")
        outcome["passed"] = documents
        return outcome

    # The extractor is optional; without it only plain text files yield text.
    client = dataiku.api_client()
    doc_extractor = DocumentExtractor(client, project.project_key) if DocumentExtractor else None
    matcher = GuardrailsMatcher(pattern)

    violating_names = set()
    unreadable_names = set()

    for entry in candidates:
        doc_name = entry.get("name", "")
        if not doc_name:
            continue

        file_ext = os.path.splitext(doc_name.lower())[1].lstrip(".")
        cache_filename = f"{doc_name}_extracted.json"
        text_content = ""

        # First choice: text cached by an earlier extraction.
        if extracted_folder:
            try:
                with extracted_folder.get_file(cache_filename) as response:
                    text_content = response.raw.read().decode("utf-8", errors="replace")
            except Exception:
                pass

        # Otherwise extract now and refresh the cache when possible.
        if not text_content:
            text_content = extract_document_text_content(
                documents_folder, doc_name, file_ext, doc_extractor, documents_folder_id
            )
            if text_content and extracted_folder:
                try:
                    extracted_folder.put_file(cache_filename, io.BytesIO(text_content.encode("utf-8")))
                except Exception as e:
                    logger.warning(f"Failed to cache extraction for {doc_name}: {e}")

        # No text at all means the document is blocked as an extraction failure.
        if not text_content:
            unreadable_names.add(doc_name)
            logger.warning(f"Text extraction failed for agent document {doc_name}")
            continue

        sidecar_file = f"{doc_name}_guardrails.json"
        is_violation, _ = verify_document_content(text_content, matcher, extracted_folder, sidecar_file)
        if is_violation:
            violating_names.add(doc_name)

    # Classify every candidate by name; content violations take precedence.
    for entry in candidates:
        entry_name = entry.get("name")
        if entry_name in violating_names:
            bucket = "content_violations"
        elif entry_name in unreadable_names:
            bucket = "extraction_failures"
        else:
            bucket = "passed"
        outcome[bucket].append(entry)

    outcome["has_violations"] = bool(outcome["content_violations"] or outcome["extraction_failures"])
    return outcome


def _get_or_create_extracted_folder(project, zone_name: str):
    """
    Return the ``extracted_documents_<zone>`` managed folder, creating it if absent.

    A newly created folder uses the "default_fs_connection" config entry as
    its connection and is added to the zone when the zone id can be resolved.
    Returns None (after logging a warning) if creation fails.
    """
    from backend.config import get_config
    from backend.utils.project_utils import get_managed_folder_in_zone, get_zone_id_by_name

    target_name = f"extracted_documents_{zone_name}"

    found = get_managed_folder_in_zone(project, target_name, zone_name)
    if found:
        return found

    try:
        fs_connection = get_config().get("default_fs_connection")
        created = project.create_managed_folder(target_name, connection_name=fs_connection)
        zone_id = get_zone_id_by_name(project, zone_name)
        if zone_id:
            flow = project.get_flow()
            flow.get_zone(zone_id).add_item(created)
    except Exception as e:
        logger.warning(f"Failed to create extracted_documents folder: {e}")
        return None
    return created


def enforce_indexing_guardrails(agent_id, docs: List, store: IUserStore, remove_document_fn):
    """
    Apply guardrails policies to an agent's documents before indexing.

    When guardrails are enabled and a pattern is configured, the documents
    are checked. Every document with a content violation or a text extraction
    failure is deleted through ``remove_document_fn`` (deletion errors are
    logged, not raised), and the agent's stored document list is updated to
    the remaining documents.

    Args:
        agent_id: The agent ID
        docs: List of document dicts
        store: The user store instance
        remove_document_fn: Callable(agent, filename) to delete documents

    Returns:
        ``(error_response, 422)`` when at least one document was blocked,
        ``(None, None)`` otherwise.
    """
    if get_guardrails_enabled() and get_guardrails_pattern():
        verdicts = check_agent_documents_guardrails(agent_id, docs)
    else:
        verdicts = {"content_violations": [], "extraction_failures": [], "passed": []}

    flagged_docs = verdicts.get("content_violations", [])
    unreadable_docs = verdicts.get("extraction_failures", [])

    # Nothing blocked: signal success to the caller.
    if not (flagged_docs or unreadable_docs):
        return None, None

    content_violation_names = [entry.get("name") for entry in flagged_docs]
    extraction_failure_names = [entry.get("name") for entry in unreadable_docs]
    blocked_filenames = {*content_violation_names, *extraction_failure_names}

    # Remove the blocked files; a failed deletion must not stop the others.
    agent = store.get_agent(agent_id)
    for filename in blocked_filenames:
        try:
            remove_document_fn(agent, filename)
            logger.info("Deleted guardrails-violating document %s for agent %s", filename, agent_id)
        except Exception as e:
            logger.warning("Failed to delete guardrails-violating document %s: %s", filename, e)

    # Keep only the documents that were not blocked and persist that list.
    remaining_docs = []
    for entry in docs:
        if entry.get("name") not in blocked_filenames:
            remaining_docs.append(entry)
    store.update_agent(agent_id, {"documents": remaining_docs})

    # Describe each category of blocked documents that actually occurred.
    reasons = []
    if flagged_docs:
        reasons.append(f"content policy violations: {', '.join(content_violation_names)}")
    if unreadable_docs:
        reasons.append(f"text extraction failures: {', '.join(extraction_failure_names)}")

    error_message = f"Documents blocked due to {' and '.join(reasons)}"
    logger.warning("Guardrails blocked documents for the agent.")

    error_response = {
        "error": "guardrails_violation",
        "message": error_message,
        "content_violations": content_violation_names,
        "extraction_failures": extraction_failure_names,
        "documents": remaining_docs,
    }
    # 422 = Unprocessable Entity
    return error_response, 422

def emit_guardrails_filter_events(blocked_agents: list[dict], pcb) -> None:
    """
    Emit guardrails events for agents filtered out in multi-agent mode.

    At most two events are sent through ``pcb``: one summarising agents with
    content violations and one summarising agents with text extraction
    failures. Each event carries the full ``blocked_agents`` list, a message
    and ``"filtered": True`` to indicate agents were filtered out, not
    blocked. Every emitted message is also logged as a warning.
    """
    violation_summaries = []
    failure_summaries = []

    for blocked in blocked_agents:
        agent_name = blocked["agentName"]
        violating_docs = blocked.get("content_violations", [])
        unreadable_docs = blocked.get("extraction_failures", [])

        if violating_docs:
            violation_summaries.append(f"'{agent_name}' ({', '.join(violating_docs)})")
        if unreadable_docs:
            failure_summaries.append(f"'{agent_name}' ({', '.join(unreadable_docs)})")

    def _publish(event_kind, message):
        # Send the event first, then mirror the message in the logs.
        event = {
            "eventKind": event_kind,
            "eventData": {
                "blocked_agents": blocked_agents,
                "message": message,
                "filtered": True,
            },
        }
        pcb(event)
        logger.warning(message)

    if violation_summaries:
        _publish(
            EventKind.GUARDRAILS_VIOLATION,
            f"Filtered agents with knowledge base documents blocked by content policy: {'; '.join(violation_summaries)}",
        )

    if failure_summaries:
        _publish(
            EventKind.GUARDRAILS_EXTRACTION_FAILED,
            f"Filtered agents with text extraction failures in knowledge base documents: {'; '.join(failure_summaries)}",
        )