from __future__ import annotations

import logging
import math
import threading
import time
from collections import defaultdict
from typing import Any, DefaultDict, Dict, Generator, List, Set

from dataiku import Dataset, api_client
from pandas import DataFrame
from typing_extensions import TypedDict

from editor.backend.utils.filtering import apply_filter_rules
from editor.backend.utils.webapp_config import webapp_config

from ..models import EdgeGroupDefinition, NodeGroupDefinition, Sampling

logger = logging.getLogger(__name__)


class NodesBatch(TypedDict):
    """One batch of node rows, as yielded by `generate_nodes_batches`."""

    # Node group definition whose filters produced the rows in `df`.
    definition: NodeGroupDefinition
    # Filtered rows of one dataset chunk, restricted to the columns relevant to `definition`.
    df: DataFrame


class EdgesBatch(TypedDict):
    """One batch of edge rows, as yielded by `generate_edges_batches`."""

    # Edge group definition whose filters produced the rows in `df`.
    definition: EdgeGroupDefinition
    # Filtered rows of one dataset chunk, restricted to the columns relevant to `definition`.
    df: DataFrame

# Number of rows per DataFrame chunk. It is passed as `chunksize` for sampled reads,
# used to advance the progress counters in the batch generators, and used as the
# divisor when estimating the total number of batches.
CHUNK_SIZE = 10000


def __get_node_relevant_cols__(definition: NodeGroupDefinition) -> List[str]:
    return list(
        set(
            [definition["primary_col"]]
            + [definition["label_col"]]
            + definition["property_list"]
            + [f["column"] for f in definition["filters_stored"]]
        )
    )


def __get_edge_relevant_cols_ordered__(definition: EdgeGroupDefinition) -> List[str]:
    # Order is very important because of the way kuzu handles the insertion using COPY FROM csv.
    # First column not in the schema is considered the source and second column not in the schema is the target.
    ordered_set: Dict[str, None] = {}
    for col in (
        [definition["source_column"]]
        + [definition["target_column"]]
        + definition["property_list"]
        + [f["column"] for f in definition["filters_stored"]]
    ):
        ordered_set[col] = None
    return list(ordered_set.keys())


def __try_get_count_records(dataset: Dataset) -> int:
    try:
        return dataset.get_last_metric_values().get_metric_by_id("records:COUNT_RECORDS")["lastValues"][0]["value"]  # type: ignore
    except:
        return -1
        

def __trigger_metrics_computation_for_datasets(datasets: List[str]) -> None:
    """
    Trigger metrics computation by spawning a new thread for each dataset (non-blocking).

    The threads are started and never joined, so this function returns immediately.
    Each thread runs `_task_compute_single_dataset_count_metrics` for one dataset
    name in the webapp's default project.
    """
    project_key = webapp_config.default_project_key

    for dataset_name in datasets:
        # Fire-and-forget: no reference to the thread is kept.
        threading.Thread(
            target=_task_compute_single_dataset_count_metrics,
            args=(dataset_name, project_key),
        ).start()


def _task_compute_single_dataset_count_metrics(dataset_name: str, project_key: str) -> None:
    """
    Worker function to compute the `records:COUNT_RECORDS` metric for a single dataset.

    Any exception is caught and logged and never re-raised.
    """
    try:
        project = api_client().get_project(project_key)
        dss_dataset_object = project.get_dataset(dataset_name)

        dss_dataset_object.compute_metrics(metric_ids=["records:COUNT_RECORDS"])

    except Exception as e:
        # exc_info keeps the traceback in the log, since nothing else reports this failure.
        logger.error(
            f"[Thread-{threading.get_ident()}] Failed to compute metrics for '{dataset_name}': {str(e)}",
            exc_info=True,
        )


def generate_nodes_batches(
    group_definitions: List[NodeGroupDefinition], sampling: Sampling | None
) -> Generator[NodesBatch, Any, None]:
    """
    Stream the source datasets of the given node groups and yield filtered batches.

    Definitions are grouped by `source_dataset` so that each dataset is read only
    once, in chunks of `CHUNK_SIZE` rows. For every chunk and every definition on
    that dataset, the definition's stored filters are applied via
    `apply_filter_rules`, and a `NodesBatch` is yielded if any rows remain.

    Args:
        group_definitions: Node group definitions to generate batches for.
        sampling: Optional sampling settings. When set and `sampling["sampling"]`
            is not "all", the sampling method and `sampling["max_rows"]` limit
            are forwarded to the dataset reader. Otherwise the whole dataset is read.

    Yields:
        NodesBatch: the definition and a copy of the filtered rows, restricted to
        the columns relevant to that definition. Empty results are skipped.
    """
    # Pre-group node definitions by dataset so we can iterate over a dataset only once.
    dataset_group_map: Dict[str, List[NodeGroupDefinition]] = defaultdict(list)
    columns_group_map: Dict[str, Set[str]] = defaultdict(set)
    for definition in group_definitions:
        source_dataset = definition["source_dataset"]
        dataset_group_map[source_dataset].append(definition)
        columns_group_map[source_dataset].update(__get_node_relevant_cols__(definition))

    for source_dataset, definitions in dataset_group_map.items():
        dataset = Dataset(source_dataset, webapp_config.default_project_key)
        count_records = __try_get_count_records(dataset)

        # The chunk size is always passed explicitly, so it cannot drift from
        # CHUNK_SIZE if the reader's default changes.
        read_options: Dict[str, Any] = {
            "columns": sorted(columns_group_map[source_dataset]),
            "chunksize": CHUNK_SIZE,
            "infer_with_pandas": False,
        }
        if sampling and sampling["sampling"] != "all":
            read_options["sampling"] = sampling["sampling"]
            read_options["limit"] = sampling["max_rows"]

        # Resolve each definition's column list once, not once per chunk.
        definitions_with_columns = [
            (definition, __get_node_relevant_cols__(definition)) for definition in definitions
        ]

        count_processed = 0
        for df in dataset.iter_dataframes(**read_options):
            for definition, columns_to_keep in definitions_with_columns:
                start_time = time.time()
                group_df = apply_filter_rules(
                    dataframe=df,
                    rules=definition["filters_stored"],
                    filter_association=definition["filters_association"],
                )
                filter_time = time.time()
                if len(group_df):
                    yield NodesBatch(definition=definition, df=group_df[columns_to_keep].copy())

                logger.debug("Generate batch timings: filter %s.", round(filter_time - start_time, 5))

            # Count the rows actually read: the last chunk is usually shorter than CHUNK_SIZE.
            count_processed += len(df)
            logger.debug(
                "Generating nodes: processed %s/%s for dataset '%s'.",
                count_processed,
                count_records,
                source_dataset,
            )


def generate_edges_batches(
    group_definitions: List[EdgeGroupDefinition], sampling: Sampling | None
) -> Generator[EdgesBatch, Any, None]:
    """
    Stream the edge datasets of the given edge groups and yield filtered batches.

    Definitions are grouped by `edge_dataset` so that each dataset is read only
    once, in chunks of `CHUNK_SIZE` rows. For every chunk and every definition on
    that dataset, the definition's stored filters are applied via
    `apply_filter_rules`, and an `EdgesBatch` is yielded if any rows remain.

    Args:
        group_definitions: Edge group definitions to generate batches for.
        sampling: Optional sampling settings. When set and `sampling["sampling"]`
            is not "all", the sampling method and `sampling["max_rows"]` limit
            are forwarded to the dataset reader. Otherwise the whole dataset is read.

    Yields:
        EdgesBatch: the definition and a copy of the filtered rows. The columns
        follow the order given by `__get_edge_relevant_cols_ordered__`. Empty
        results are skipped.
    """
    # Pre-group edges by dataset so we can iterate over a dataset only once.
    dataset_group_map: Dict[str, List[EdgeGroupDefinition]] = defaultdict(list)
    columns_group_map: Dict[str, Set[str]] = defaultdict(set)
    for definition in group_definitions:
        edge_dataset = definition["edge_dataset"]
        dataset_group_map[edge_dataset].append(definition)
        columns_group_map[edge_dataset].update(__get_edge_relevant_cols_ordered__(definition))

    for source_dataset, definitions in dataset_group_map.items():
        dataset = Dataset(source_dataset, project_key=webapp_config.default_project_key)
        count_records = __try_get_count_records(dataset)

        # The reader gets a sorted list, not the raw set, so the request is
        # deterministic and matches the node generator. The per-definition column
        # order is restored below when the batch is sliced.
        read_options: Dict[str, Any] = {
            "columns": sorted(columns_group_map[source_dataset]),
            "chunksize": CHUNK_SIZE,
            "infer_with_pandas": False,
        }
        if sampling and sampling["sampling"] != "all":
            read_options["sampling"] = sampling["sampling"]
            read_options["limit"] = sampling["max_rows"]

        # Resolve each definition's ordered column list once, not once per chunk.
        definitions_with_columns = [
            (definition, __get_edge_relevant_cols_ordered__(definition)) for definition in definitions
        ]

        count_processed = 0
        for df in dataset.iter_dataframes(**read_options):
            for definition, columns_to_keep in definitions_with_columns:
                start_time = time.time()
                group_df = apply_filter_rules(
                    dataframe=df,
                    rules=definition["filters_stored"],
                    filter_association=definition["filters_association"],
                )
                filter_time = time.time()
                if len(group_df):
                    yield EdgesBatch(definition=definition, df=group_df[columns_to_keep].copy())

                logger.debug("Generate batch timings: filter %s.", round(filter_time - start_time, 5))

            # Count the rows actually read: the last chunk is usually shorter than CHUNK_SIZE.
            count_processed += len(df)
            logger.debug("Generating edges from '%s': %s/%s.", source_dataset, count_processed, count_records)

def calculate_estimated_total_batches(
    nodes_definitions: List[NodeGroupDefinition],
    edges_definitions: List[EdgeGroupDefinition],
    sampling: Sampling | None,
) -> int:
    """
    Calculate estimated total batches using the heuristic approach.

    This method:
    1. Groups definitions by dataset
    2. Gets record counts for each dataset (capped by `max_rows` when sampling is active)
    3. Calculates batches per dataset: chunks_per_dataset * definitions_using_dataset

    If at least one dataset has no record-count metric, metric computation is
    triggered in the background for those datasets and no estimate is produced.

    Returns:
        int: Total estimated batches that will be processed.
             Returns -1 if metrics are not available and need to be computed.
    """
    # Count how many node definitions use each dataset
    node_datasets_count: DefaultDict[str, int] = defaultdict(int)
    for node_definition in nodes_definitions:
        node_datasets_count[node_definition["source_dataset"]] += 1

    # Count how many edge definitions use each dataset
    edge_datasets_count: DefaultDict[str, int] = defaultdict(int)
    for edge_definition in edges_definitions:
        edge_datasets_count[edge_definition["edge_dataset"]] += 1

    dataset_record_counts: Dict[str, Any] = {}
    datasets_without_metrics: List[str] = []

    # A dataset used by both nodes and edges is only looked up once.
    for dataset_name in set(node_datasets_count) | set(edge_datasets_count):
        dataset = Dataset(dataset_name, webapp_config.default_project_key)
        record_count = __try_get_count_records(dataset)
        if record_count == -1:
            datasets_without_metrics.append(dataset_name)
        else:
            dataset_record_counts[dataset_name] = record_count

    # If any datasets don't have metrics, trigger computation and return -1
    if datasets_without_metrics:
        __trigger_metrics_computation_for_datasets(datasets_without_metrics)
        return -1

    is_sampled = bool(sampling) and sampling["sampling"] != "all"

    def _chunks_for(record_count: Any) -> int:
        """Return how many CHUNK_SIZE chunks are needed to read a dataset with `record_count` rows."""
        effective_rows = int(record_count)
        if is_sampled:
            max_rows = sampling["max_rows"]
            # max_rows may come in as a string.
            if isinstance(max_rows, str):
                max_rows = int(max_rows)
            effective_rows = min(effective_rows, max_rows)
        return math.ceil(effective_rows / CHUNK_SIZE)

    # All datasets have metrics available. Nodes and edges use the same formula:
    # each definition sees every chunk of its dataset once.
    total_estimated_batches = 0
    for datasets_count in (node_datasets_count, edge_datasets_count):
        for dataset_name, definition_count in datasets_count.items():
            total_estimated_batches += _chunks_for(dataset_record_counts[dataset_name]) * definition_count

    return total_estimated_batches