import ast
import json
import re
from typing import Any, Dict, List, Optional, Set, Union

from common.llm_assist.logging import logger

# Placeholder that mask_keys_in_json substitutes for the value of every masked key.
REPLACEMENT_VALUE = "...removed_in_logs"


def try_parse_json(data: str) -> Any:
    """Parse *data* as JSON, returning ``None`` instead of raising on failure.

    ``None`` is also returned for input ``json.loads`` cannot accept at all
    (e.g. ``None``). Note that the JSON literal ``null`` decodes to ``None`` as
    well, so it is indistinguishable from a failure.
    """
    try:
        return json.loads(data)
    except (json.JSONDecodeError, TypeError):
        return None


def _longest_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Return the longest JSON object embedded anywhere in *text*.

    Each ``{`` is tried as the start of an object with ``JSONDecoder.raw_decode``,
    which consumes exactly one complete value. Nested objects and braces inside
    string values are therefore handled correctly, unlike a non-greedy ``{...}``
    regex, which stops at the first ``}`` and truncates nested objects.

    Returns ``{}`` when *text* contains no ``{`` at all, and ``None`` when it
    contains braces but none of them starts a valid object.
    """
    start = text.find("{")
    if start == -1:
        return {}
    decoder = json.JSONDecoder()
    best: Optional[Dict[str, Any]] = None
    best_span = -1
    while start != -1:
        try:
            obj, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            # Strict ">" keeps the first of several equally long objects, like max().
            if end - start > best_span:
                best, best_span = obj, end - start
        start = text.find("{", start + 1)
    return best


def extract_json(response_text: str, json_pattern: Optional[str]) -> Dict[str, Any]:
    """Extract the most likely JSON object from free-form text (e.g. an LLM reply).

    :param response_text: Text that may contain a JSON object among other prose.
    :param json_pattern: Optional regex locating candidate JSON strings. When
        given, the longest match is parsed. When omitted, the longest valid
        embedded JSON object is located by decoding, which is nesting-aware.
    :return: The parsed JSON. If no candidate parses, the whole
        ``response_text`` is parsed instead, and ``{}`` is returned when nothing
        usable (or only a falsy value) is found. The whole-text fallback can
        yield a non-dict value, e.g. when ``response_text`` is a JSON array.
    """
    if json_pattern:
        candidates = re.findall(json_pattern, response_text)
        # The longest match is the most likely to be the complete object.
        json_data = try_parse_json(max(candidates, key=len, default="{}"))
    else:
        json_data = _longest_json_object(response_text)
    if json_data is None:
        # We are not inside an ``except`` block, so logger.exception would attach
        # a meaningless "NoneType: None" traceback; log at the same ERROR level.
        logger.error("Error parsing json")
        json_data = try_parse_json(response_text)
    return json_data or {}


def mask_keys_in_json(data: Any, keys_to_mask: Set[str]) -> Any:
    """
    Recursively replace the values of ``keys_to_mask`` with ``REPLACEMENT_VALUE``.

    Handles nested dicts and lists, plus two kinds of serialized payloads:

    * a string inside a list that ``ast.literal_eval`` turns into a dict
      (e.g. ``"{'token': 'x'}"``) is masked and converted back with ``str()``;
    * any other string that decodes as a JSON object or array is replaced by
      the masked, *decoded* structure.

    Strings that decode to JSON scalars (``"123"``, ``"true"``, ``"null"``) are
    returned unchanged, as are all other values. The input is not mutated; new
    dicts/lists are built.

    :param data: Value to mask.
    :param keys_to_mask: Keys whose values are replaced, at any nesting depth.
    :return: The masked copy of ``data``.
    """
    if isinstance(data, dict):
        return {
            k: REPLACEMENT_VALUE if k in keys_to_mask else mask_keys_in_json(v, keys_to_mask) for k, v in data.items()
        }

    if isinstance(data, list):
        masked_list = []
        for item in data:
            if isinstance(item, str):
                # Try to parse a stringified Python dict (e.g. the output of str(dict)).
                try:
                    parsed = ast.literal_eval(item)
                except Exception:
                    parsed = None
                if isinstance(parsed, dict):
                    # Convert back to a string so the list keeps its original element type.
                    masked_list.append(str(mask_keys_in_json(parsed, keys_to_mask)))
                    continue
            masked_list.append(mask_keys_in_json(item, keys_to_mask))
        return masked_list

    if isinstance(data, str):
        try:
            parsed = json.loads(data)
        except json.JSONDecodeError:
            return data  # Leave as-is if not JSON
        # Only expand JSON containers: decoding "123" / "true" / "null" would
        # silently change a plain string value into an int / bool / None.
        if isinstance(parsed, (dict, list)):
            return mask_keys_in_json(parsed, keys_to_mask)
        return data

    return data


def try_literal_eval(data: str) -> Any:
    """Evaluate *data* as a Python literal, yielding ``None`` if that fails.

    ``ast.literal_eval`` only evaluates literal syntax (strings, numbers,
    tuples, lists, dicts, sets, booleans, ``None``), never arbitrary
    expressions. Any error it raises is swallowed deliberately, because this is
    a best-effort probe.
    """
    try:
        value = ast.literal_eval(data)
    except Exception:
        value = None
    return value


def coerce_to_dict(data: Union[str, dict]) -> Any:
    """Best-effort conversion of *data* into a parsed Python structure.

    Dicts are returned unchanged. Anything else is parsed as JSON first and,
    failing that, as a Python literal (e.g. ``"{'a': 1}"``). If neither parse
    succeeds, the original *data* is returned as-is.

    Note: the parsed value is returned even when it is not a dict
    (``"[1, 2]"`` -> ``[1, 2]``).
    """
    if isinstance(data, dict):
        return data
    parsed = try_parse_json(data)
    if parsed is None:
        parsed = try_literal_eval(data)
    # Compare against None rather than truthiness, so empty containers and falsy
    # scalars ("{}", "[]", "0") come back parsed instead of as the raw string.
    return data if parsed is None else parsed

def mask_keys(data: Any, keys_to_mask: List[str]) -> Any:
    """
    Masks specified keys in a dict, list, or JSON string.

    Strings are parsed as JSON first, then as a Python literal (e.g. a
    ``str(dict)`` dump), and the resulting structure is masked.

    :param data: Structure, or its string form, to mask.
    :param keys_to_mask: Key names whose values are replaced at any depth.
    :return: The masked structure, or the original *data* unchanged if parsing
        or masking fails.
    """
    keys_set = set(keys_to_mask)
    original = data

    if isinstance(data, str):
        parsed = try_parse_json(data)
        # Explicit None check: a falsy JSON value (e.g. ``false``) is a valid
        # parse and must not be re-parsed as a Python literal.
        if parsed is None:
            parsed = try_literal_eval(data)
        if parsed is None:
            return data  # Couldn't parse the string
        data = parsed

    try:
        return mask_keys_in_json(data, keys_set)
    except Exception:
        # Return the caller's input, not the intermediate parse result, as documented.
        # NOTE(review): this fails open (unmasked data is returned); if the result
        # feeds logs, consider returning a placeholder instead.
        return original


def load_json(data: Optional[str], default_value: Any = None):
    """Decode *data* as JSON, or return *default_value* when *data* is empty.

    Any falsy *data* (``None`` or ``""``) short-circuits to *default_value*.
    Otherwise ``json.loads`` is applied, and its errors propagate to the caller.
    """
    if not data:
        return default_value
    return json.loads(data)