import json
import logging
import os
import shutil
from pathlib import Path

import dataiku
from graphrag.config.load_config import load_config

# Module-wide logger used by every helper in this file. It is registered under
# the fixed name "graphrag utils" (not __name__), so logging configuration must
# target that name.
logger: logging.Logger = logging.getLogger("graphrag utils")


def update_prompt(output_folder, file_name: str, content_value: str) -> None:
    """Replace ``prompts/<file_name>`` in a folder object with *content_value*.

    Best-effort: any error raised while replacing the file is logged with its
    traceback and swallowed, so a failing prompt update never aborts the caller.
    An empty *content_value* is skipped with a warning and leaves any existing
    prompt file untouched.

    Args:
        output_folder: Object exposing ``delete_path(path)`` and
            ``upload_data(path, data)``. Presumably a ``dataiku.Folder``;
            confirm against call sites.
        file_name: Prompt file name, relative to the folder's ``prompts/`` directory.
        content_value: New prompt text; uploaded as UTF-8 bytes.
    """
    prompt_file_path = f"prompts/{file_name}"

    if not content_value:
        # Nothing to write: warn and stop here instead of falling through to the
        # "successfully" log below, which would misreport a skipped update.
        logger.warning(f"update_prompt called with empty prompt value for {file_name}")
        return

    try:
        # Remove the previous version first, then upload the replacement.
        logger.debug(f"Deleting old prompt file {prompt_file_path}")
        output_folder.delete_path(prompt_file_path)
        logger.debug(f"Creating new prompt in: {prompt_file_path}")
        output_folder.upload_data(prompt_file_path, content_value.encode("utf-8"))
    except Exception as e:
        logger.exception(f"Error processing prompt {file_name}: {e}")
    else:
        logger.debug(f"Processed prompt {prompt_file_path} successfully")


def update_prompt_temp(output_folder: Path, file_name: str, content_value: str) -> None:
    """Write *content_value* to ``<output_folder>/prompts/<file_name>`` on local disk.

    Local-filesystem counterpart of ``update_prompt``. The ``prompts`` directory is
    always created, even when there is nothing to write. Best-effort: any error
    is logged with its traceback and swallowed rather than raised.

    Args:
        output_folder: Root directory. A ``str`` is accepted as well and is
            converted to ``Path``.
        file_name: Prompt file name relative to ``prompts/``. It may contain
            sub-directories, which are created as needed.
        content_value: Prompt text, written as UTF-8. If empty, a warning is
            logged and no file is written.
    """
    try:
        prompts_dir = Path(output_folder) / "prompts"
        prompts_dir.mkdir(parents=True, exist_ok=True)
        prompt_file_path = prompts_dir / file_name

        if not content_value:
            # Stop here so the "successfully" log below is only emitted on a real write.
            logger.warning(f"update_prompt_temp called with empty prompt value for {file_name}")
            return

        # file_name may include sub-directories (e.g. "sub/prompt.txt"), in which
        # case creating prompts_dir alone is not enough.
        prompt_file_path.parent.mkdir(parents=True, exist_ok=True)

        if prompt_file_path.exists():
            logger.debug(f"Deleting old prompt file {prompt_file_path}")
            prompt_file_path.unlink()

        logger.debug(f"Creating new prompt in: {prompt_file_path}")
        prompt_file_path.write_text(content_value, encoding="utf-8")
        logger.debug(f"Processed prompt {prompt_file_path} successfully")
    except Exception as e:
        logger.exception(f"Error processing prompt {file_name}: {e}")


# Config sections (besides the top-level ``parallelization``) whose own
# ``parallelization`` settings follow the user-selected stagger/num_threads.
_PARALLELIZED_SECTIONS = (
    "embeddings",
    "entity_extraction",
    "summarize_descriptions",
    "community_reports",
    "claim_extraction",
)

# custom_config keys copied one-to-one onto attributes of ``embed_graph``.
_EMBED_GRAPH_KEYS = ("num_walks", "walk_length", "window_size", "iterations", "random_seed")


def _apply_parallelization(graph_rag_config, stagger, num_threads):
    """Set stagger/num_threads on the global and every per-step parallelization config.

    A ``None`` value leaves the corresponding setting as loaded from the config file.
    """
    targets = [graph_rag_config.parallelization]
    targets.extend(
        getattr(graph_rag_config, section).parallelization for section in _PARALLELIZED_SECTIONS
    )
    for target in targets:
        if stagger is not None:
            target.stagger = stagger
        if num_threads is not None:
            target.num_threads = num_threads


def get_graphrag_config(custom_config, root_dir):
    """Load the GraphRAG config from *root_dir* and apply user overrides from *custom_config*.

    Overrides are only applied for keys that are present with a non-``None`` value.
    Missing keys keep whatever ``load_config`` produced, instead of being clobbered
    with ``None``.

    Args:
        custom_config: Mapping read via ``.get``. Keys used: ``text_column``,
            ``source_column``, ``title_column``, ``attribute_columns``, ``chunk_size``,
            ``chunk_overlap``, ``group_by_column``, ``embed_graph_enabled``,
            ``num_walks``, ``walk_length``, ``window_size``, ``iterations``,
            ``random_seed``, ``entity_types``, ``stagger``, ``num_threads``.
        root_dir: Project root passed to graphrag's ``load_config``.

    Returns:
        The loaded GraphRAG config object with the overrides applied.

    Raises:
        ValueError: If ``stagger`` is present but not convertible to ``float``.
    """
    text_column = custom_config.get("text_column")
    source_column = custom_config.get("source_column")
    title_column = custom_config.get("title_column")
    # ``or []`` also covers keys that are present but explicitly set to None.
    attribute_columns = custom_config.get("attribute_columns") or []
    chunk_size = custom_config.get("chunk_size")
    chunk_overlap = custom_config.get("chunk_overlap")
    group_by_columns = custom_config.get("group_by_column") or []

    # Drop unset entries (e.g. a missing text_column) so the list never contains
    # None and the "any columns selected?" check below is meaningful.
    selected_columns = [column for column in [text_column, *attribute_columns] if column]

    logger.debug(f"Selected columns for indexing: {selected_columns}")

    graph_rag_config = load_config(root_dir, None)

    input_config = graph_rag_config.input
    if input_config:
        if selected_columns:
            input_config.document_attribute_columns = selected_columns
        if text_column:
            input_config.text_column = text_column
        if title_column:
            input_config.title_column = title_column
        if source_column:
            input_config.source_column = source_column

    chunks_config = graph_rag_config.chunks
    if chunks_config:
        if chunk_size is not None:
            chunks_config.size = chunk_size
        if chunk_overlap is not None:
            chunks_config.overlap = chunk_overlap
        if group_by_columns:
            chunks_config.group_by_columns = group_by_columns

    if custom_config.get("embed_graph_enabled", False):
        # NOTE(review): only the node2vec parameters are copied here; embed_graph's
        # own on/off switch (if the config model has one) is not touched, so
        # whether embedding actually runs presumably depends on the loaded
        # settings. Confirm against the graphrag version in use.
        for key in _EMBED_GRAPH_KEYS:
            value = custom_config.get(key)
            if value is not None:
                setattr(graph_rag_config.embed_graph, key, value)

    entity_types = custom_config.get("entity_types")
    if entity_types is not None:
        graph_rag_config.entity_extraction.entity_types = entity_types

    raw_stagger = custom_config.get("stagger")
    stagger = float(raw_stagger) if raw_stagger is not None else None
    num_threads = custom_config.get("num_threads")

    # TODO: expose embeddings.parallelization batch_size / batch_max_tokens
    # (candidate values: 16 and 8191); they currently keep the config defaults.
    _apply_parallelization(graph_rag_config, stagger, num_threads)

    # Serializing the whole config is costly; only do it when DEBUG is actually on.
    # default=str keeps a debug log from crashing on non-JSON values (e.g. paths).
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            "graphrag config is %s",
            json.dumps(graph_rag_config.dict(), indent=2, default=str),
        )
    return graph_rag_config
