"""
Business logic for trace operations.
Orchestrates parsing, repository, and external data source access.
"""
import json
import time
from typing import List, Optional

import dataiku

from traces_explorer.backend.models.trace import Trace, TokenUsage
from traces_explorer.backend.repositories.trace_repository import TraceRepository
from traces_explorer.backend.services.trace_parser import TraceParser


class TraceService:
    """
    Service layer for trace operations.

    Coordinates three collaborators: the parser that turns raw LLM response
    strings into trace dictionaries, the repository that stores the resulting
    Trace objects, and the Dataiku API client that exposes the webapp
    configuration.
    """

    def __init__(self, repository: TraceRepository, parser: TraceParser, dataiku_api, logger):
        """
        Initialize trace service.

        Args:
            repository: TraceRepository for data access
            parser: TraceParser for parsing logic
            dataiku_api: Dataiku API client (its ``webapp_config`` is read
                by ``load_traces``)
            logger: Logger instance
        """
        self.repository = repository
        self.parser = parser
        self.dataiku_api = dataiku_api
        self.logger = logger

    def read_column_content(self, dataset_name: str, column_name: str) -> List[str]:
        """
        Read every value of one column from a Dataiku dataset.

        The dataset is loaded as a dataframe first, then the column is
        selected from it.

        Args:
            dataset_name: Name of dataset
            column_name: Name of column

        Returns:
            The column values as returned by ``tolist()``. Values are not
            converted here; callers presumably expect strings (TODO confirm
            how empty cells are represented).

        Raises:
            Exception: If the column is not present in the dataset (chained
                from the underlying KeyError). Any error raised while loading
                the dataset is logged and re-raised unchanged.
        """
        try:
            data_frame = dataiku.Dataset(dataset_name).get_dataframe()
        except Exception as e:
            self.logger.error("Error reading dataset '%s': %s", dataset_name, e)
            raise

        # Only the column lookup is guarded by KeyError, so a KeyError raised
        # while loading the dataset is never misreported as a missing column.
        try:
            column = data_frame[column_name]
        except KeyError as e:
            self.logger.error(
                "Column '%s' not found in dataset '%s': %s",
                column_name,
                dataset_name,
                e,
            )
            raise Exception(
                f"Column '{column_name}' not found in dataset '{dataset_name}'"
            ) from e

        return column.tolist()

    def load_traces(self) -> int:
        """
        Load traces from configured Dataiku dataset.

        The repository is cleared first, so it ends up empty when nothing is
        configured or when the dataset cannot be read. Responses that fail to
        parse are logged and skipped; they do not abort the load.

        Returns:
            Number of traces loaded
        """
        config = self.dataiku_api.webapp_config
        llm_response_column = config.get("llm_responses_column")
        llm_response_dataset = config.get("llm_responses_dataset")

        # Clear existing traces
        self.repository.clear()

        # If no dataset or column configured, nothing to load
        if not llm_response_dataset or not llm_response_column:
            self.logger.info("No dataset or column configured, skipping trace loading")
            return 0

        try:
            responses = self.read_column_content(
                llm_response_dataset, llm_response_column
            )
        except Exception as e:
            self.logger.error("Failed to read traces from dataset: %s", e)
            return 0

        traces_to_add: List[Trace] = []

        for response_index, response_str in enumerate(responses):
            try:
                parsed_traces = self.parser.parse_traces(response_str)

                for trace_index, (trace_name, trace_data) in enumerate(parsed_traces):
                    # IDs encode the row position and the position of the
                    # trace within that row's response.
                    trace_id = f"{response_index}_{trace_index}"
                    trace = self._create_trace_detail(
                        trace_id=trace_id,
                        trace_name=trace_name,
                        trace_data=trace_data,
                    )
                    traces_to_add.append(trace)

            except json.JSONDecodeError as exc:
                self.logger.warning(
                    "Skipping response at index %d due to JSON parsing error: %s",
                    response_index,
                    exc,
                )
            except Exception as exc:
                # Best-effort: keep loading the remaining responses, but keep
                # the traceback so the unexpected failure can be diagnosed.
                self.logger.exception(
                    "Unexpected error processing response at index %d: %s",
                    response_index,
                    exc,
                )

        # Add all traces in single write operation
        self.repository.add_all(traces_to_add)

        trace_count = len(traces_to_add)
        self.logger.info("Loaded %d traces", trace_count)
        return trace_count

    def _create_trace_detail(
        self,
        trace_id: str,
        trace_name: str,
        trace_data: dict,
        display_name_override: Optional[str] = None,
    ) -> Trace:
        """
        Create Trace from parsed trace data.

        Helper method to avoid duplication between load_traces and process_pasted_trace.

        Args:
            trace_id: Unique identifier for the trace
            trace_name: Name of the trace (from parsing)
            trace_data: Parsed trace data dictionary
            display_name_override: Optional display name to use instead of
                extracting one; an empty string counts as "no override"

        Returns:
            Trace object with processed trace tree
        """
        # Extract display name (use override or extract from trace)
        if display_name_override:
            display_name = display_name_override
        else:
            display_name = self.parser.extract_display_name(trace_data)

        # Extract result
        result = self.parser.extract_output(trace_data)

        # Accumulate token usage: the parser updates this dict in place
        tokens_dict = {
            "promptTokens": 0,
            "completionTokens": 0,
            "totalTokens": 0,
            "estimatedCost": 0.0,
        }
        self.parser.accumulate_usage_metadata(trace_data, tokens_dict)

        tokens = TokenUsage(
            prompt_tokens=tokens_dict["promptTokens"],
            completion_tokens=tokens_dict["completionTokens"],
            total_tokens=tokens_dict["totalTokens"],
            estimated_cost=tokens_dict["estimatedCost"],
        )

        # Process trace to add IDs (the parser is expected to return a copy
        # rather than modify trace_data; confirm against TraceParser)
        processed_node = self.parser.process_trace_add_ids(trace_data)

        # Create Trace object
        return Trace(
            id=trace_id,
            name=display_name,
            result=result,
            begin=trace_data.get("begin"),
            duration=trace_data.get("duration"),
            trace_name=trace_name,
            tokens=tokens,
            parent_node=processed_node,
        )

    def list_traces(self) -> List[Trace]:
        """
        List all traces.

        Returns:
            List of Trace objects held by the repository
        """
        return self.repository.list_all()

    def get_trace(self, trace_id: str) -> Optional[Trace]:
        """
        Get trace by ID (full details).

        Args:
            trace_id: Trace ID to retrieve

        Returns:
            Whatever the repository returns for this ID (expected: the Trace
            if found, None otherwise)
        """
        return self.repository.get_by_id(trace_id)

    def process_pasted_trace(self, trace_json: str, name: Optional[str] = None) -> Trace:
        """
        Process a pasted trace JSON string without storing in repository.

        Only the first trace found in the JSON is used.

        Args:
            trace_json: Raw JSON string containing trace data
            name: Optional display name override

        Returns:
            Trace object with processed trace tree

        Raises:
            ValueError: If JSON is invalid or no trace found
        """
        try:
            # Parse JSON string
            parsed_traces = self.parser.parse_traces(trace_json)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON: {e}") from e

        # Take first trace only
        if not parsed_traces:
            raise ValueError("No valid trace object found in JSON")

        trace_name, trace_data = parsed_traces[0]

        # Generate timestamp-based ID (milliseconds)
        trace_id = f"LOCAL_PASTED_TRACE_{int(time.time() * 1000)}"

        # Create and return the Trace using the shared helper
        return self._create_trace_detail(
            trace_id=trace_id,
            trace_name=trace_name,
            trace_data=trace_data,
            display_name_override=name,
        )
