import asyncio
import logging
import os
import shutil
import tempfile
import threading
from datetime import datetime
from pathlib import Path

import dataiku
from dataiku.customrecipe import (
    get_input_names_for_role,
    get_output_names_for_role,
    get_recipe_config,
    get_recipe_resource,
)
from dku_graphrag.index.dataiku_graph_index_builder import DataikuGraphragIndexBuilder
from dku_graphrag.utils.graphrag_config import get_graphrag_config, update_prompt, update_prompt_temp


def create_temp_dir(prefix: str = None) -> Path:
    """
    Create a fresh working directory under the system temp folder.

    The directory is named ``graphrag_db_<timestamp>_thread_<thread id>``. When
    ``prefix`` is given it is appended to that name (despite the parameter's
    name it ends up as a suffix). A directory that already exists under the
    same name is removed first, so the returned directory is always empty.

    Args:
        prefix: Optional label appended to the generated directory name.

    Returns:
        Path of the newly created directory.
    """
    # Look the logger up locally so the warning below does not depend on the
    # module-level `logger` global already being assigned when this runs.
    log = logging.getLogger("graphrag recipe")

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    dir_name = f"graphrag_db_{stamp}_thread_{threading.get_ident()}"
    if prefix:
        dir_name = f"{dir_name}_{prefix}"

    temp_path = Path(tempfile.gettempdir()) / dir_name
    if temp_path.exists():
        shutil.rmtree(temp_path)
    temp_path.mkdir(parents=True, exist_ok=True)

    # Best effort only: a failing chmod must not abort the recipe.
    # NOTE(review): 0o777 makes the directory world-writable — confirm this is required.
    try:
        temp_path.chmod(0o777)
    except OSError as e:
        log.warning(f"Permission setting skipped (reason: {e})")
    return temp_path


def get_temp_numba_cache_dir() -> Path:
    """
    Create (or retrieve) a persistent temporary directory for Numba caching.

    The directory has a fixed name, ``numba_cache``, inside the system temp
    folder (e.g. /tmp/numba_cache on Unix or the equivalent on Windows). This
    allows reusing the cache across runs while avoiding creating files in the
    home directory.

    Returns:
        Path of the cache directory, which exists when the function returns.
    """
    # Look the logger up locally so the warning below does not depend on the
    # module-level `logger` global already being assigned when this runs.
    log = logging.getLogger("graphrag recipe")

    persistent_cache = Path(tempfile.gettempdir()) / "numba_cache"
    persistent_cache.mkdir(parents=True, exist_ok=True)

    # Best effort only: chmod fails e.g. when the directory was created by
    # another user, and that must not abort the recipe.
    # NOTE(review): 0o777 makes the shared cache world-writable — confirm this is required.
    try:
        persistent_cache.chmod(0o777)
    except OSError as e:
        log.warning(f"Permission setting skipped for numba cache (reason: {e})")
    return persistent_cache


# Recipe parameters and the dataset / folder names bound to the recipe roles.
custom_config = get_recipe_config()
input_dataset_name = get_input_names_for_role("input_dataset")[0]
output_folder_name = get_output_names_for_role("output_folder")[0]
verbose_mode = custom_config.get("verbose_mode", False)

# Bind the module-level logger before any helper or statement below can try to log.
logger = logging.getLogger("graphrag recipe")

text_column = custom_config.get("text_column")
source_column = custom_config.get("source_column")
title_column = custom_config.get("title_column")
# `or []` also covers a key that is present but set to None, which would
# otherwise break the list concatenation that builds the selected columns.
attribute_columns = custom_config.get("attribute_columns") or []
chunk_size = custom_config.get("chunk_size")
chunk_overlap = custom_config.get("chunk_overlap")
group_by_columns = custom_config.get("group_by_column", [])

# Per-run working directory for the settings, prompts, input CSV and index output.
temp_root_dir = create_temp_dir()

# NUMBA_CACHE_DIR is exported before numba is imported; presumably numba reads
# the variable at import time, which is why the import sits here — TODO confirm.
persistent_numba_cache = get_temp_numba_cache_dir()
os.environ["NUMBA_CACHE_DIR"] = str(persistent_numba_cache)
from numba import jit

chat_completion_llm_id = custom_config.get("chat_completion_llm_id")
embedding_llm_id = custom_config.get("embedding_llm_id")

# force=True replaces any handlers already installed on the root logger.
logging.basicConfig(
    level=logging.DEBUG if verbose_mode else logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    force=True,
)


# Handles on the recipe's input dataset and output managed folder.
input_dataset = dataiku.Dataset(input_dataset_name)
output_folder = dataiku.Folder(output_folder_name)
# The working directory is uploaded into this folder at the end of the run.
# NOTE(review): the folder is cleared before the build has succeeded — confirm that is intended.
output_folder.clear()

# Stage the settings.yaml bundled with the recipe resources in the working directory.
resource_folder_path = get_recipe_resource()
settings_source = Path(resource_folder_path) / "settings.yaml"
settings_target = temp_root_dir / "settings.yaml"

settings_content = settings_source.read_bytes()
settings_target.write_bytes(settings_content)

logger.debug(f"Copied new settings.yaml file from resources to {settings_target}")


# Stage the bundled prompt files under <temp_root_dir>/prompts, keeping their
# relative layout.
# TODO: is it necessary to also write each prompt straight into the output
# managed folder (under "prompts/")? An earlier variant did that; it was kept
# here as a disabled string literal and has been removed.
prompts_source = os.path.join(resource_folder_path, "prompts")

for root, _, files in os.walk(prompts_source):
    for file in files:
        source_file_path = os.path.join(root, file)
        relative_path = os.path.relpath(source_file_path, prompts_source)
        target_file_path = temp_root_dir / "prompts" / relative_path
        target_file_path.parent.mkdir(parents=True, exist_ok=True)
        # copy2 also carries over file metadata such as timestamps.
        shutil.copy2(source_file_path, target_file_path)


# Columns to load: the text column followed by the attribute columns. Duplicates
# are dropped (first occurrence wins) so a text column that is also selected as
# an attribute is not requested twice.
selected_columns = list(dict.fromkeys([text_column] + attribute_columns))
logger.info(f"Input dataset selected_columns: {selected_columns}")


df = input_dataset.get_dataframe(columns=selected_columns)


# Export the rows as CSV to <temp_root_dir>/input/<dataset name>.csv.
# Despite its name, this path is rooted at temp_root_dir rather than relative.
input_file_rel_path = temp_root_dir / "input" / f"{input_dataset_name}.csv"
input_file_rel_path.parent.mkdir(parents=True, exist_ok=True)

df.to_csv(input_file_rel_path, index=False, encoding="utf-8")

logger.debug(f"Saved input dataset CSV at {input_file_rel_path}")


# Pre-create <temp_root_dir>/output.
temp_output_dir = temp_root_dir / "output"
temp_output_dir.mkdir(parents=True, exist_ok=True)

logger.debug(f"Selected columns for indexing: {selected_columns}")


# Build the GraphRAG configuration from the recipe parameters and the working directory.
graph_rag_config = get_graphrag_config(custom_config, temp_root_dir)

# Re-root the configured vector store location under the working directory.
# NOTE(review): the value stored is a Path, not a str — confirm the vector store accepts it.
temp_db_dir = temp_root_dir / graph_rag_config.embeddings.vector_store["db_uri"]
graph_rag_config.embeddings.vector_store["db_uri"] = temp_db_dir
logger.info(f"Setting vector db absolute path: {temp_db_dir}")


# Apply the prompt values from the recipe configuration to the prompt files
# staged in the working directory: (prompt file name, recipe config key).
for prompt_file, config_key in (
    ("entity_extraction.txt", "entity_extraction_prompt"),
    ("summarize_descriptions.txt", "summarize_descriptions_prompt"),
    ("community_report.txt", "community_report_prompt"),
    ("claim_extraction.txt", "claim_extraction_prompt"),
):
    update_prompt_temp(temp_root_dir, prompt_file, custom_config.get(config_key))

logger.debug(f"graphrag config is {graph_rag_config}")

# Run the indexing pipeline to completion; asyncio.run drives the coroutine
# on a fresh event loop and blocks until it finishes.
builder = DataikuGraphragIndexBuilder(chat_completion_llm_id, embedding_llm_id)
asyncio.run(
    builder.run_build_index_pipeline(config=graph_rag_config, verbose=verbose_mode, resume=None, memprofile=False)
)


def upload_to_managed_folder(managed_folder, local_path, target_path=""):
    """
    Recursively upload a local file or directory tree to a managed folder.

    A directory is walked entry by entry; each entry is uploaded under
    ``target_path`` joined with the entry name by "/" (or under the bare entry
    name when ``target_path`` is empty). Anything that is not a directory is
    opened in binary mode and streamed to ``managed_folder.upload_stream``.

    Args:
        managed_folder: Object exposing ``upload_stream(path, stream)``.
        local_path: Local file or directory to upload.
        target_path: Destination path inside the managed folder.
    """
    source = Path(local_path)

    # Leaf case: stream the single file and stop.
    if not source.is_dir():
        with source.open("rb") as stream:
            managed_folder.upload_stream(target_path, stream)
        logger.info(f"Uploaded {source} to managed folder at {target_path}")
        return

    # Directory case: recurse into every entry, extending the target path.
    for child in source.iterdir():
        child_target = child.name if not target_path else f"{target_path}/{child.name}"
        upload_to_managed_folder(managed_folder, child, child_target)


# Push the whole working directory (recursively) to the root of the output folder.
upload_to_managed_folder(managed_folder=output_folder, local_path=temp_root_dir, target_path="")

# Remove the working directory now that its content has been uploaded.
# NOTE(review): this is only reached when every earlier step succeeded; a failed
# run leaves the working directory behind.
shutil.rmtree(temp_root_dir)
logger.info(f"Deleted temporary directory: {temp_root_dir}")
