import json
import re
from typing import Any, Dict, List, Union

import dataiku
from common.backend.constants import DEFAULT_MAX_LLM_TOKENS, DEFAULT_TEMPERATURE, LOWEST_TEMPERATURE
from common.backend.models.base import (
    LlmHistory,
    MediaSummary,
    UploadChainTypes,
)
from common.backend.utils.config_utils import resolve_webapp_param
from common.backend.utils.context_utils import add_llm_step_trace
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.file_utils import file_path_text_parts, file_path_to_image_parts
from common.llm_assist.llm_api_handler import llm_setup
from common.llm_assist.logging import logger
from dataiku.langchain.dku_llm import DKULLM
from dataikuapi.dss.llm import (
    DSSLLMCompletionQuery,
    DSSLLMCompletionQueryMultipartMessage,
    DSSLLMCompletionResponse,
    DSSLLMImageGenerationQuery,
    DSSLLMImageGenerationResponse,
    DSSLLMStreamedCompletionFooter,
)
from dataikuapi.utils import DataikuException

# Webapp configuration, read from dataiku_api once when this module is imported.
webapp_config: Dict[str, Any] = dataiku_api.webapp_config

# Not referenced in this part of the module; presumably the name of the
# setting that enables the fallback LLM (TODO confirm against its readers).
FALLBACK_LLM_KEY = "ENABLE_FALLBACK_LLM"

def compare_versions(version1, version2):
    """Compare two dotted numeric version strings.

    Missing trailing components count as zero, so ``"13.0"`` and ``"13.0.0"``
    denote the same version.

    Args:
        version1: Version to test, e.g. ``"13.0.2"``. A value that is not a
            dot-separated list of integers (or is not a string at all) is
            treated as a development build, i.e. newer than any reference.
        version2: Reference version; must be a well-formed dotted version.

    Returns:
        ``-1`` if ``version1`` is older than ``version2``, ``1`` if it is newer
        (or malformed), ``0`` if both denote the same version.
    """

    def parse_version(version):
        try:
            return [int(part) for part in version.split(".")]
        except (ValueError, AttributeError):
            return None  # Malformed version (non-numeric part or not a string)

    parts1 = parse_version(version1)
    if parts1 is None:
        # If version1 is malformed, treat it as a dev version
        return 1

    # version2 is always well-formed
    parts2 = [int(part) for part in version2.split(".")]

    # Pad the shorter version with zeros so that 13.0 == 13.0.0 while
    # 13.0 < 13.0.1 still holds.
    width = max(len(parts1), len(parts2))
    parts1 += [0] * (width - len(parts1))
    parts2 += [0] * (width - len(parts2))

    # Lists compare lexicographically, which is exactly component-wise order.
    return (parts1 > parts2) - (parts1 < parts2)


# Image generation is supported with the following providers:
#   - OpenAI (DALL-E 3)
#   - Azure OpenAI (DALL-E 3)
#   - Google Vertex (Imagen 1 and Imagen 2)
#   - Stability AI (Stable Image Core, Stable Diffusion 3.0, Stable Diffusion 3.0 Turbo)
#   - Bedrock Titan Image Generator
#   - Bedrock Stable Diffusion XL 1


def get_llm_capabilities(get_fallback: bool = False) -> Dict[str, bool]:
    """Report which optional features the configured LLM can be used with.

    The answer is derived from the webapp configuration, from the connection
    type and model name found in the LLM id (its first and third
    ``:``-separated fields), and from the DSS version of the instance.

    Args:
        get_fallback: When true, inspect the fallback LLM instead of the one
            stored under ``llm_id`` in the webapp configuration.

    Returns:
        A dict with the boolean keys ``multi_modal``, ``streaming`` and
        ``image_generation``. All three are ``False`` when the fallback LLM id
        is ``None`` or when DSS is older than 12.5.0.
    """
    settings: Dict[str, str] = dataiku_api.webapp_config
    if not get_fallback:
        llm_id = settings["llm_id"]
    else:
        from common.llm_assist.fallback import get_fallback_id

        llm_id = get_fallback_id()
        if llm_id is None:
            return {"multi_modal": False, "streaming": False, "image_generation": False}

    # The "force_*" settings switch a capability on regardless of the model.
    forced_streaming = bool(settings.get("force_streaming_mode", False))
    forced_multi_modal = bool(settings.get("force_multi_modal_mode", False))
    streaming = forced_streaming
    multi_modal = forced_multi_modal
    image_generation = False

    instance_info = dataiku.api_client().get_instance_info()
    dss_version = instance_info.raw.get("dssVersion", "0.0.0")
    if dss_version == "0.0.0":
        logger.warn("Could not retrieve DSS version")

    def at_least(minimum: str) -> bool:
        return compare_versions(dss_version, minimum) >= 0

    # Ids with fewer than three fields keep the forced/default values above.
    segments = llm_id.split(":")
    if len(segments) >= 3:
        connexion, model = segments[0], segments[2]

        if connexion == "openai":
            streams = model.startswith("gpt")
            sees_images = model.startswith("gpt-4")
        elif connexion == "bedrock":
            streams = "amazon.titan-" in model or "anthropic.claude-" in model
            sees_images = model.startswith("anthropic.claude-3") and at_least("13.0.2")
        elif connexion == "azureopenai":
            streams = at_least("12.6.2")
            sees_images = False
        elif connexion == "vertex":
            streams = False
            sees_images = model == "gemini-pro-vision"
        else:
            streams = False
            sees_images = False

        streaming = forced_streaming or streams
        multi_modal = forced_multi_modal or sees_images

        # Image generation needs DSS >= 13.0.0, the feature toggle, and a
        # non-empty dedicated image generation LLM id.
        image_generation = (
            at_least("13.0.0")
            and bool(settings.get("enable_image_generation", False))
            and settings.get("image_generation_llm_id", "") != ""
        )

    if not at_least("12.5.0"):
        return {"multi_modal": False, "streaming": False, "image_generation": False}
    return {
        "multi_modal": multi_modal,
        "streaming": streaming,
        "image_generation": image_generation,
    }


def handle_prompt_media_explanation(system_prompt: str, has_media: bool) -> str:
    """Extend the system prompt with instructions about AI-generated media.

    Args:
        system_prompt: The base system prompt.
        has_media: Whether the media explanation must be appended.

    Returns:
        ``system_prompt`` unchanged when ``has_media`` is false; otherwise the
        prompt followed by the explanation of the ``generated_media_by_ai``
        metadata and an example chat history.
    """
    if has_media:
        logger.debug("Appending media explanation to system prompt")
        # The separators in this example are plain ASCII commas so the sample
        # history reads as one consistent list.
        example = """
        -- Start of example --
            [{"role": "user", "content":"hello"},
            {"role": "assistant", "content":"How can I help you today?"},
            {"role": "user", "content":"generate blue circle"},
            {"role": "assistant", "content":'{"generated_media_by_ai": {"images": [{"file_path": "userwx_17117_04RMXS.png", "file_format": "png", "referred_file_path":""}]}}'},
            {"role": "user", "content": "Thank you"},]
            Expected Answer: You are welcome
        -- End of example --
        """
        system_prompt = f"""{system_prompt}. 
            During the current conversation, some media such as images could have been generated by an image generation agent.
            In that case Chat history could include metadata about the media generated by the image generation agent in the form of a generated_media_by_ai JSON object in the assistant message.
            When responding, do not include generated_media_by_ai object in your answers to the user.
            You do not have access to the media itself. If the user asks about the media, you can inform them that you don't have access to it.
            You can ignore the media in the current conversation and continue with your tasks.
            # Example :"""
        system_prompt = f"{system_prompt}\n{example}"
    return system_prompt


def append_summaries_to_completion_msg(
    media_summaries: List[MediaSummary], msg: DSSLLMCompletionQueryMultipartMessage
) -> DSSLLMCompletionQueryMultipartMessage:
    """Attach uploaded media to a multipart message and add it to its query.

    Summaries whose ``chain_type`` is an image (or a document rendered as
    images) go through ``file_path_to_image_parts``; short documents go
    through ``file_path_text_parts``, preceded once by an introductory text
    part. Summaries with any other chain type are ignored.

    Args:
        media_summaries: Summaries of the files uploaded with the user query.
        msg: The multipart message under construction.

    Returns:
        The multipart message. If an error occurs it is logged and the message
        is returned as built so far, without having been added.
    """
    try:
        logger.debug("Appending media summaries to completion message")
        image_chain_types = (
            UploadChainTypes.IMAGE.value,
            UploadChainTypes.DOCUMENT_AS_IMAGE.value,
        )
        is_first_text = True
        for summary in media_summaries:
            chain_type: Union[str, None] = summary.get("chain_type")

            if chain_type in image_chain_types:
                msg = file_path_to_image_parts(summary, msg)
            elif chain_type == UploadChainTypes.SHORT_DOCUMENT.value:
                if is_first_text:
                    # Introduce the extracted text only once, before the first document.
                    msg.with_text("""The user uploaded file(s) with extracted text along with their query. Here is the extracted text:
                    """)
                    is_first_text = False
                msg = file_path_text_parts(summary, msg)
        msg.add()
    except Exception as e:
        # Best effort: a failure while attaching media must not break the caller.
        logger.exception(f"Error when creating completion query from media summaries: {e}")
    # Honour the declared return type (the function used to return None).
    return msg


def get_llm_completion(llm: DKULLM) -> DSSLLMCompletionQuery:
    """Create a completion query configured from the given LLM's settings.

    ``maxOutputTokens`` is taken from ``llm.max_tokens`` (falling back to
    ``DEFAULT_MAX_LLM_TOKENS`` when unset or zero). ``temperature`` is only
    written when the LLM carries a non-``None`` value, and a warning is logged
    when it exceeds ``LOWEST_TEMPERATURE``.
    """
    # Go through the project's LLM handle rather than llm._llm_handle to avoid
    # relying on private fields.
    llm_handle = dataiku_api.default_project.get_llm(llm.llm_id)
    query: DSSLLMCompletionQuery = llm_handle.new_completion()
    query.settings["maxOutputTokens"] = llm.max_tokens or DEFAULT_MAX_LLM_TOKENS

    temperature = llm.temperature
    if temperature is not None:
        query.settings["temperature"] = temperature
        if temperature > LOWEST_TEMPERATURE:
            threshold = LOWEST_TEMPERATURE if temperature < 1.0 else 1.0
            logger.warn(f"The LLM '{llm.llm_id}' temperature '{temperature}' is '>= {threshold}'.")

    logger.debug(f"completion settings for LLM '{llm.llm_id}': {query.settings})")
    return query


def get_llm_image_generation(llm: DKULLM) -> DSSLLMImageGenerationQuery:
    """Create an image generation query for the LLM identified by ``llm.llm_id``."""
    image_llm_id = llm.llm_id
    llm_handle = dataiku_api.default_project.get_llm(image_llm_id)
    generation_query = llm_handle.new_images_generation()
    logger.debug(f"Generation settings are ready for the LLM: '{image_llm_id}'")
    return generation_query


def extract_response_trace(
    response: Union[
        DSSLLMCompletionResponse,
        DSSLLMStreamedCompletionFooter,
        DSSLLMImageGenerationResponse,
    ]
) -> Dict[str, Any]:
    """Return the trace carried by an LLM response, or ``{}`` if there is none.

    The trace is read from the ``trace`` attribute when the response has one,
    otherwise from the ``"trace"`` key when the response is a dict. Any error
    raised while reading it is logged and an empty dict is returned.
    """
    try:
        if hasattr(response, "trace"):
            return response.trace
        if isinstance(response, dict) and "trace" in response:
            return response["trace"]
        return {}
    except Exception as e:
        logger.exception(f"Error when handling response trace: {e}")
        return {}


def handle_response_trace(
    response: Union[
        DSSLLMCompletionResponse,
        DSSLLMStreamedCompletionFooter,
        DSSLLMImageGenerationResponse,
    ]
) -> None:
    """Pass the response's trace to ``add_llm_step_trace`` when it is non-empty."""
    trace = extract_response_trace(response)
    if not trace:
        return
    add_llm_step_trace(trace)

def parse_error_messages(error_as_str: Union[str, DataikuException]) -> str:
    """Extract a human-readable message from an LLM error.

    The error text is scanned for flat ``{...}`` groups (no nested braces).
    The first group that parses as a JSON object and has a ``"message"`` key
    provides the result.

    Args:
        error_as_str: The error, as a string or as an exception whose ``str()``
            contains the payload.

    Returns:
        ``" Error message: <message>"`` when a message is found, otherwise an
        empty string. This function never raises on unparseable content.
    """
    for candidate in re.findall(r"\{[^{}]*\}", str(error_as_str)):
        try:
            payload = json.loads(candidate)
        except json.JSONDecodeError:
            # Brace groups that are not JSON (e.g. a Python dict repr) must not
            # make the error-reporting path itself fail.
            continue
        if "message" in payload:
            return f" Error message: {payload['message']}"
    logger.debug(f"Error message from LLM couldn't be parsed: {error_as_str}")
    return ""


def get_alternative_llm(llm_id_key: str) -> DKULLM:
    """Build the LLM configured under ``llm_id_key``, or fall back to the main LLM.

    Args:
        llm_id_key: Webapp configuration key holding the alternative LLM id.
            ``"title_llm_id"`` and ``"json_decision_llm_id"`` resolve their own
            temperature and max-token settings; any other key uses the defaults.

    Returns:
        A ``DKULLM`` for the alternative LLM id, or the result of
        ``llm_setup.get_llm()`` when that id is not set but ``llm_id`` is.

    Raises:
        Exception: When neither ``llm_id_key`` nor ``llm_id`` is set in the
            webapp configuration.
    """
    alternative_id = dataiku_api.webapp_config.get(llm_id_key)
    if not alternative_id:
        main_id = dataiku_api.webapp_config.get("llm_id")
        if main_id:
            logger.debug(f"As the LLM '{llm_id_key}' is not set the 'Main LLM' ('{main_id}') will be used")
            return llm_setup.get_llm()
        raise Exception("No LLM ID found in webapp config")

    logger.debug(f"Using alternative LLM ID: {alternative_id}")

    # (advanced-settings toggle, temperature key, max-tokens key) per LLM role.
    role_settings = {
        "title_llm_id": (
            "use_advanced_title_llm_settings",
            "title_llm_temperature",
            "max_title_llm_tokens",
        ),
        "json_decision_llm_id": (
            "use_advanced_decision_llm_settings",
            "decision_llm_temperature",
            "max_decision_llm_tokens",
        ),
    }
    keys = role_settings.get(llm_id_key)
    if keys is None:
        logger.debug(f"The  alternative LLM ID: '{alternative_id}' will be used with the default parameters (temperature={DEFAULT_TEMPERATURE}, max_tokens={DEFAULT_MAX_LLM_TOKENS}).")
        temperature = DEFAULT_TEMPERATURE
        max_tokens = DEFAULT_MAX_LLM_TOKENS
    else:
        advanced_key, temperature_key, max_tokens_key = keys
        advanced_enabled = webapp_config.get(advanced_key, False) or False
        temperature = resolve_webapp_param(temperature_key, default_value=DEFAULT_TEMPERATURE, advanced_mode_enabled=advanced_enabled)
        max_tokens = resolve_webapp_param(max_tokens_key, default_value=DEFAULT_MAX_LLM_TOKENS, advanced_mode_enabled=advanced_enabled)
        if llm_id_key == "json_decision_llm_id":
            if isinstance(temperature, (float, int)) and (temperature > LOWEST_TEMPERATURE):
                logger.warn(f"The 'Decisions LLM' temperature is not set to the minimum value ('{LOWEST_TEMPERATURE}'): current value: '{temperature}'. It must be as close to '{LOWEST_TEMPERATURE}' as possible.")
            if max_tokens == DEFAULT_MAX_LLM_TOKENS:
                logger.warn(f"The 'Decisions LLM' max output tokens is set to the minimum allowed value ('{DEFAULT_MAX_LLM_TOKENS}'): It is recommended to set a high value for accurate results.")

    alternative_llm = DKULLM(llm_id=alternative_id, max_tokens=max_tokens)
    # We can't set a 'None' temperature value in the DKULLM constructor before DSS 14
    alternative_llm.temperature = temperature
    return alternative_llm


def add_history_to_completion(
    completion: DSSLLMCompletionQuery,
    chat_history: List[LlmHistory],
) -> DSSLLMCompletionQuery:
    """Replay the chat history onto a completion query.

    For each history item, a non-empty ``input`` is added as a ``user``
    message and a non-empty ``output`` as an ``assistant`` message, in that
    order. The same completion object is returned.
    """
    field_roles = (("input", "user"), ("output", "assistant"))
    for entry in chat_history:
        for field, role in field_roles:
            text = entry.get(field)
            if text:
                completion.with_message(message=text, role=role)
    return completion


def get_llm_friendly_name(llm_id: str, project_key: str) -> str:
    """Look up the friendly name of an LLM among a project's LLMs.

    Args:
        llm_id: Id of the LLM to look for.
        project_key: Key of the project whose LLMs are listed.

    Returns:
        The ``friendlyName`` of the first listed LLM whose ``id`` matches
        (``llm_id`` itself when that entry has no friendly name), or an empty
        string when no listed LLM matches.
    """
    available: List[Dict[str, str]] = dataiku.api_client().get_project(project_key).list_llms()
    match = next((entry for entry in available if entry.get("id") == llm_id), None)
    if match is None:
        return ""
    return match.get("friendlyName", llm_id)