import time
from abc import ABC, abstractmethod
from collections.abc import Generator
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union

from common.backend.constants import LLM_API_ERROR
from common.backend.models.base import (
    ConversationParams,
    LlmHistory,
    LLMStep,
    LLMStepDesc,
    MediaSummary,
    RetrievalSummaryJson,
)
from common.backend.services.sources.sources_builder import build_augmented_llm_or_agent_sources
from common.backend.utils.context_utils import TIME_FORMAT, LLMStepName
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.llm_utils import (
    add_history_to_completion,
    append_summaries_to_completion_msg,
    get_llm_capabilities,
    get_llm_completion,
    handle_response_trace,
    parse_error_messages,
)
from common.llm_assist.fallback import get_fallback_completion, is_fallback_enabled
from common.llm_assist.logging import logger
from dataiku.core.knowledge_bank import MultipartContext
from dataiku.langchain.dku_llm import DKULLM
from dataikuapi.dss.llm import (
    DSSLLMCompletionQuery,
    DSSLLMCompletionQueryMultipartMessage,
    DSSLLMCompletionResponse,
    DSSLLMStreamedCompletionChunk,
    DSSLLMStreamedCompletionFooter,
)
from dataikuapi.utils import DataikuException

# Fallback system prompt: used by
# GenericAnswersChain.load_default_role_and_guidelines_prompts when the webapp
# configuration has no "system_prompt" entry.
DEFAULT_SYSTEM_PROMPT = (
    "\n The following is a helpful and professional conversation between a user and an assistant. "
    "Please assist with the following user query. Provide a clear and concise response. "
    "Focus on relevance and utility. Avoid speculation. \n"
)


class GenericAnswersChain(ABC):
    """Abstract base class for Answers chains.

    Child classes must implement every abstract property and method. The
    concrete entry point is :meth:`run_completion_query`: it builds the
    completion query (system prompt, chat history, media summaries, retrieved
    context, user query) and runs it in streaming or non-streaming mode. When
    the first attempt fails and a fallback LLM is enabled, the query is retried
    once with the fallback completion.
    """

    # Class-level default so the ``forced_non_streaming`` getter keeps working
    # for subclasses whose ``__init__`` does not call ``super().__init__()``.
    _forced_non_streaming: bool = False

    def __init__(self):
        self._forced_non_streaming = False

    @property
    def forced_non_streaming(self) -> bool:
        """Whether the non-streaming path is used even if the LLM can stream."""
        return self._forced_non_streaming

    @forced_non_streaming.setter
    def forced_non_streaming(self, value):
        """Set the flag.

        Raises:
            ValueError: if ``value`` is not a ``bool``.
        """
        if isinstance(value, bool):
            self._forced_non_streaming = value
        else:
            raise ValueError("forced_non_streaming must be a boolean")

    @property
    def webapp_config(self):
        """The webapp configuration exposed by ``dataiku_api``."""
        return dataiku_api.webapp_config

    @property
    @abstractmethod
    def act_like_prompt(self) -> str:
        raise NotImplementedError("Subclasses must implement act_like_prompt property")

    @property
    @abstractmethod
    def system_prompt(self) -> str:
        raise NotImplementedError("Subclasses must implement system_prompt property")

    @property
    @abstractmethod
    def llm(self) -> DKULLM:
        """The LLM handed to ``get_llm_completion`` to create the completion query."""
        raise NotImplementedError("Subclasses must implement llm property")

    @property
    @abstractmethod
    def chain_purpose(self) -> str:
        raise NotImplementedError("Subclasses must implement chain_purpose property")

    @abstractmethod
    def load_role_and_guidelines_prompts(self, params: ConversationParams):
        """Called by :meth:`prepare_computed_system_prompt` before the system prompt is computed."""
        raise NotImplementedError("Subclasses must implement load_role_and_guidelines_prompts method")

    @abstractmethod
    def get_computed_system_prompt(self, params: ConversationParams) -> str:
        """Return the system prompt that is added to the completion as the ``system`` message."""
        raise NotImplementedError("Subclasses must implement get_computed_system_prompt method")

    def prepare_computed_system_prompt(self, params: ConversationParams) -> str:
        """Load the role/guidelines prompts, then return the computed system prompt."""
        self.load_role_and_guidelines_prompts(params)
        return self.get_computed_system_prompt(params)

    @abstractmethod
    def get_computing_prompt_step(self) -> LLMStep:
        """Return the step yielded first by :meth:`run_completion_query`."""
        raise NotImplementedError("Subclasses must implement get_computing_prompt_step method")

    @abstractmethod
    def get_querying_step(self, params: ConversationParams) -> LLMStep:
        """Return the step yielded before the first non-streaming attempt."""
        raise NotImplementedError("Subclasses must implement get_querying_step method")

    @abstractmethod
    def finalize_streaming(
        self, params: ConversationParams, answer_context: Union[str, Dict[str, Any], List[str]]
    ) -> RetrievalSummaryJson:
        """Build the final summary of a streamed run.

        Called with the answer context once the stream has ended, or with the
        error message string when streaming failed and no fallback is attempted.
        """
        raise NotImplementedError("Subclasses must implement finalize_streaming method")

    @abstractmethod
    def finalize_non_streaming(
        self, params: ConversationParams, answer_context: Union[str, Dict[str, Any], List[str]]
    ) -> RetrievalSummaryJson:
        """Build the final summary of a non-streamed run.

        Called with the answer context (including the ``"answer"`` key) on
        success, or with the error message string when the query failed and no
        fallback is attempted.
        """
        raise NotImplementedError("Subclasses must implement finalize_non_streaming method")

    @abstractmethod
    def get_retrieval_context(
        self, params: ConversationParams
    ) -> Tuple[Optional[Union[MultipartContext, str]], Dict[str, Any]]:
        """Return ``(retrieved_context, answer_context)`` for the given conversation.

        ``retrieved_context`` may be falsy, in which case no context message is
        added to the completion query.
        """
        raise NotImplementedError("Subclasses must implement get_retrieval_context method")

    @abstractmethod
    def get_as_json(
        self,
        generated_answer: Union[str, Dict[str, Any], List[str]],
        user_profile: Optional[Dict[str, Any]] = None,
    ) -> RetrievalSummaryJson:
        """Default body only logs an error and returns an empty dict; subclasses must override."""
        logger.error("get_as_json method not implemented")
        return {}

    @abstractmethod
    def create_query_from_history_and_update_params(
        self, chat_history: List[LlmHistory], user_query: str, params: ConversationParams
    ) -> ConversationParams:
        """Default body returns ``params`` unchanged; subclasses must override."""
        return params

    def __run_non_streaming_query(
        self,
        params: ConversationParams,
        completion: DSSLLMCompletionQuery,
        answer_context: Dict[str, Any],
        first_attempt: bool = True,
    ) -> Generator[Union[LLMStepDesc, RetrievalSummaryJson], Any, None]:
        """Execute the completion in one call and yield the step, then the final summary.

        Args:
            params: conversation parameters; ``"global_start_time"`` is required.
            completion: the prepared completion query. On a retry
                (``first_attempt=False``) it is replaced by the fallback completion.
            answer_context: mutated in place with ``"aggregated_sources_list"``
                and ``"answer"``.
            first_attempt: ``False`` when this call is the fallback retry.

        Yields:
            A ``{"step": ...}`` description, then the result of
            ``finalize_non_streaming`` (or the items of the fallback retry).

        Raises:
            ValueError: if ``params`` has no ``"global_start_time"`` (raised when
                the generator is first advanced).
        """
        start_time = time.perf_counter()
        global_start_time = params.get("global_start_time")
        if not global_start_time:
            raise ValueError("global_start_time is not provided")
        step = self.get_querying_step(params) if first_attempt else LLMStep.USING_FALLBACK_LLM
        logger.debug({"step": step.name}, log_conv_id=True)
        yield {"step": step}
        try:
            if not first_attempt:
                completion = get_fallback_completion(completion)
            resp: DSSLLMCompletionResponse = completion.execute()
            handle_response_trace(resp)
            # An empty/None text must not become the literal string "None".
            response = str(resp.text) if resp.text else ""
            agg_sources_list = build_augmented_llm_or_agent_sources(resp, completion.llm)
            answer_context["aggregated_sources_list"] = agg_sources_list
            # No text but an error message in the raw payload: surface the error as the answer.
            if not resp.text and isinstance(resp._raw, dict) and resp._raw.get("errorMessage"):
                response = str(resp._raw.get("errorMessage"))
            logger.debug(f"""Time ===> taken by getting first chunk: {round(time.perf_counter() - start_time, 2)} secs
            Time ===> GLOBAL taken by getting first chunk: {(time.time() - global_start_time):.2f} secs
            """, log_conv_id=True)
            answer_context["answer"] = response
            yield self.finalize_non_streaming(params, answer_context)

        except Exception as e:
            error_message = LLM_API_ERROR
            logger.exception(f"{error_message}: {e}.", log_conv_id=True)
            # checking if we can use a fallback LLM
            fallback_enabled = is_fallback_enabled(completion.llm)
            if first_attempt and fallback_enabled:
                yield from self.__run_non_streaming_query(
                    params=params,
                    completion=completion,
                    answer_context=answer_context,
                    first_attempt=False,
                )
            else:
                yield self.finalize_non_streaming(params, error_message)

    def __run_streaming_query(
        self,
        params: ConversationParams,
        completion: DSSLLMCompletionQuery,
        answer_context: Dict[str, Any],
        first_attempt: bool = True,
    ) -> Generator[
        Union[
            DSSLLMStreamedCompletionChunk,
            DSSLLMStreamedCompletionFooter,
            LLMStepDesc,
            RetrievalSummaryJson,
        ],
        Any,
        None,
    ]:
        """Execute the completion in streaming mode, yielding chunks as they arrive.

        Args:
            params: conversation parameters; ``"global_start_time"`` is required.
            completion: the prepared completion query. On a retry
                (``first_attempt=False``) it is replaced by the fallback completion.
            answer_context: mutated in place with ``"aggregated_sources_list"``
                when the footer arrives for a retrieval-augmented LLM or an agent.
            first_attempt: ``False`` when this call is the fallback retry.

        Yields:
            Every streamed chunk/footer, then a ``STREAMING_END`` step and the
            result of ``finalize_streaming`` when it is truthy. On failure
            without a fallback: the ``finalize_streaming`` result for the error
            message, then a chunk carrying that message.

        Raises:
            ValueError: if ``params`` has no ``"global_start_time"`` (raised when
                the generator is first advanced).
        """
        start_time = time.perf_counter()
        global_start_time = params.get("global_start_time")
        if not global_start_time:
            raise ValueError("global_start_time is not provided")
        # Resolved before the try block so the except handler can always read it.
        chain_purpose: str = params.get("chain_purpose") or LLMStepName.UNKNOWN.value
        first_chunk_pending = True
        if not first_attempt:
            yield {"step": LLMStep.USING_FALLBACK_LLM}
        try:
            if not first_attempt:
                # Inside the try (as in the non-streaming path) so that a failure
                # to build the fallback completion is reported like any other error.
                completion = get_fallback_completion(completion)
            for chunk in completion.execute_streamed():
                yield chunk
                if first_chunk_pending:
                    logger.debug(f"""Time ===> taken by getting first chunk: {round(time.perf_counter() - start_time, 2)} secs
                        Time ===> GLOBAL taken by getting first chunk: {(time.time() - global_start_time):.2f} secs""", log_conv_id=True)
                    first_chunk_pending = False
                if isinstance(chunk, DSSLLMStreamedCompletionFooter):
                    handle_response_trace(chunk)
                    if any(id_type in completion.llm.llm_id for id_type in ["retrieval-augmented-llm", "agent:"]):
                        agg_sources_list = build_augmented_llm_or_agent_sources(chunk, completion.llm)
                        answer_context["aggregated_sources_list"] = agg_sources_list

            yield {"step": LLMStep.STREAMING_END}
            result = self.finalize_streaming(params, answer_context)
            if result:
                yield result
        except Exception as e:
            error_message = self._parse_streaming_error_message(e, chain_purpose)
            logger.exception(f"{error_message}", log_conv_id=True)
            # checking if we can use a fallback LLM
            fallback_enabled = is_fallback_enabled(completion.llm)
            if first_attempt and fallback_enabled:
                # The fallback LLM may not support streaming: pick the matching runner.
                fallback_streams = get_llm_capabilities(get_fallback=True).get("streaming", False)
                if fallback_streams:
                    yield from self.__run_streaming_query(
                        params=params,
                        completion=completion,
                        answer_context=answer_context,
                        first_attempt=False,
                    )
                else:
                    yield from self.__run_non_streaming_query(
                        params=params,
                        completion=completion,
                        answer_context=answer_context,
                        first_attempt=False,
                    )
            else:
                yield self.finalize_streaming(params, error_message)
                yield DSSLLMStreamedCompletionChunk({"text": error_message})

    def __update_completion_query(
        self,
        completion: DSSLLMCompletionQuery,
        params: ConversationParams,
        retrieved_context: Optional[Union[MultipartContext, str]],
    ) -> DSSLLMCompletionQuery:
        """Add media summaries and the retrieved context to ``completion`` as user messages.

        Args:
            completion: the completion query to extend.
            params: read for ``"media_summaries"`` and ``"previous_media_summaries"``.
            retrieved_context: a string, a ``MultipartContext``, or a falsy value
                (nothing is added in that case).

        Returns:
            The same ``completion`` object.

        Raises:
            ValueError: if ``retrieved_context`` is truthy but neither a ``str``
                nor a ``MultipartContext``.
        """
        media_summaries: List[MediaSummary] = params.get("media_summaries") or []
        previous_media_summaries: List[MediaSummary] = params.get("previous_media_summaries") or []
        summaries = media_summaries + previous_media_summaries
        if summaries:
            msg: DSSLLMCompletionQueryMultipartMessage = completion.new_multipart_message(role="user")
            append_summaries_to_completion_msg(summaries, msg)
        if retrieved_context:
            if isinstance(retrieved_context, str):
                logger.debug(f"Retrieved context is: {retrieved_context}", log_conv_id=True)
                completion.with_message(message=f"Retrieved source context:\n {retrieved_context}", role="user")
            elif isinstance(retrieved_context, MultipartContext):
                retrieved_context.add_to_completion_query(completion, role="user")
            else:
                error_message = f"Unsupported type for retrieved_context: {type(retrieved_context)}. Expected MultipartContext or str."
                logger.error(error_message, log_conv_id=True)
                raise ValueError(error_message)
        return completion

    def _initialise_completion(self, params: ConversationParams) -> DSSLLMCompletionQuery:
        """Create the completion for ``self.llm`` with the system prompt and chat history."""
        completion: DSSLLMCompletionQuery = get_llm_completion(self.llm)
        computed_system_prompt = self.prepare_computed_system_prompt(params)
        completion.with_message(message=computed_system_prompt, role="system")
        logger.debug(f"""Prompt: {computed_system_prompt if computed_system_prompt else "No prompt"}""", log_conv_id=True)
        # ``or []`` also covers a "chat_history" key explicitly set to None.
        completion = add_history_to_completion(completion, params.get("chat_history") or [])
        return completion

    def run_completion_query(
        self, params: ConversationParams
    ) -> Generator[
        Union[
            LLMStepDesc,
            RetrievalSummaryJson,
            DSSLLMStreamedCompletionChunk,
            DSSLLMStreamedCompletionFooter,
        ],
        Any,
        None,
    ]:
        """Build the completion query for ``params`` and run it.

        Yields the computing-prompt step first, then everything produced by the
        streaming runner (when the LLM capabilities report streaming and
        ``forced_non_streaming`` is not set) or by the non-streaming runner.
        """
        yield {"step": self.get_computing_prompt_step()}

        # TODO: name retrieved_context might have to change
        retrieved_context, answer_context = self.get_retrieval_context(params)
        # _initialise_completion must come after get_retrieval_context as
        # the system prompt can change depending on retrieval outcome
        completion: DSSLLMCompletionQuery = self._initialise_completion(params)
        completion = self.__update_completion_query(completion, params, retrieved_context)
        completion.with_message(params.get("user_query", ""), role="user")

        llm_capabilities = get_llm_capabilities()
        # A missing "streaming" entry is treated as "no streaming support".
        if llm_capabilities.get("streaming", False) and not self.forced_non_streaming:
            yield from self.__run_streaming_query(
                params=params,
                completion=completion,
                answer_context=answer_context,
            )
        else:
            yield from self.__run_non_streaming_query(
                params=params, completion=completion, answer_context=answer_context
            )

    def load_default_role_and_guidelines_prompts(self) -> Tuple[str, str]:
        """Return ``(act_like_prompt, system_prompt)`` read from the webapp configuration.

        ``"primer_prompt"`` defaults to an empty string and ``"system_prompt"``
        to ``DEFAULT_SYSTEM_PROMPT`` when the key is absent.
        """
        config = self.webapp_config
        act_like_prompt = config.get("primer_prompt", "")
        system_prompt = config.get("system_prompt", DEFAULT_SYSTEM_PROMPT)
        return act_like_prompt, system_prompt

    def _parse_streaming_error_message(self, exception: Exception, chain_purpose: str) -> str:
        """Return the user-facing error message for a failed LLM call.

        The message is ``LLM_API_ERROR``; for a ``DataikuException`` the parsed
        error details are appended and the failure is logged with ``chain_purpose``.
        """
        error_message = LLM_API_ERROR
        if isinstance(exception, DataikuException):
            logger.exception(
                f"LLM step {chain_purpose} failed with DataikuException: {exception}", log_conv_id=True
            )
            message = parse_error_messages(exception)
            error_message = LLM_API_ERROR + " " + message
        return error_message
