import concurrent.futures
import itertools
import json
import time
import uuid
from enum import Enum
from typing import Dict, Generator, List, Optional, Union

import pandas as pd
from common.backend.constants import CONVERSATION_DEFAULT_NAME
from common.backend.db.sql.queries import (
    CreateIndexQueryBuilder,
    DeleteQueryBuilder,
    UpdateQueryBuilder,
    WhereCondition,
    get_post_queries,
)
from common.backend.db.sql.tables_managers import GenericLoggingDatasetSQL
from common.backend.models.base import (
    Conversation,
    ConversationInfo,
    ConversationType,
    Feedback,
    LlmHistory,
    MediaSummary,
    QuestionData,
)
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.dataiku_utils import get_llm_friendly_name
from common.backend.utils.rag_utils import load_images_data
from common.backend.utils.sql_timing import log_query_time
from common.llm_assist.logging import logger
from dataiku.sql.expression import Operator
from portal.backend.models import PortalQuestionData, PortalSource, ToolCall
from werkzeug.exceptions import BadRequest

# Identifier passed as `dataset_conf_id` to the base SQL manager; presumably the
# webapp-config key that names the logging dataset (confirm in GenericLoggingDatasetSQL).
DB_NAME_CONF_ID = "logging_dataset"


class RecordState(Enum):
    """Values written to the ``state`` column of a logged conversation record."""

    # State given to every record inserted by `add_record`.
    PRESENT = "present"
    # State set by the delete operations when `permanent_delete` is disabled.
    DELETED = "deleted"
    # State set by the clear-history operations.
    CLEARED = "cleared"


# Column names handed to the base SQL manager. `ConversationSQL.add_record`
# builds its value list in this exact order, so keep both in sync.
COLUMNS = [
    "conversation_id",
    "conversation_name",
    "llm_name",
    "user",
    "message_id",
    "question",
    "answer",
    "filters",
    "sources",
    "feedback_value",
    "feedback_choice",
    "feedback_message",
    "timestamp",
    "state",
    "llm_context",
    "generated_media",
]


class ConversationSQL(GenericLoggingDatasetSQL):
    def __init__(self):
        """Configure the manager from the webapp configuration.

        Initialises the base class with ``COLUMNS`` and ``DB_NAME_CONF_ID``,
        reads the ``permanent_delete`` option, then runs the optional index
        creation and thread-pool set-up.
        """
        config = dataiku_api.webapp_config
        super().__init__(
            config=config,
            columns=COLUMNS,
            dataset_conf_id=DB_NAME_CONF_ID,
            default_project_key=dataiku_api.default_project_key,
        )

        # True (the default): clear/delete operations issue DELETE statements.
        # False: they only rewrite the `state` column.
        self.permanent_delete = self.config.get("permanent_delete", True)
        self.__init_indexes()
        self._init_thread_executor()

    def __init_indexes(self) -> None:
        index_chat_history: bool = bool(self.config.get("index_chat_history", False))
        if not index_chat_history:
            return
        logger.debug("Creating indexes")
        indexes = [
            {"name": "user_index", "columns": ["user"]},
            {"name": "conversation_id_index", "columns": ["conversation_id"]},
            {"name": "timestamp_index", "columns": ["timestamp"]},
        ]

        for index in indexes:
            try:
                query = (
                    CreateIndexQueryBuilder(self.dataset)
                    .set_index_name(index["name"])
                    .add_columns(index["columns"])
                    .build()
                )
                self.executor.query_to_df(query, post_queries=get_post_queries(self.dataset))
                logger.debug(f"Index {index['name']} created")
            except Exception:
                logger.debug(f"Couldn't create index {index['name']} because Index may already exist. Proceeding...")
        return

    def _init_thread_executor(self):
        self.thread_executor = None
        max_thread_workers_number = self.config.get("max_thread_workers_number", 0)
        if max_thread_workers_number:
            logger.debug(f"Creating a new ThreadPoolExecutor with max_workers = {max_thread_workers_number}")
            self.thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_thread_workers_number)

    @log_query_time
    def get_user_conversations(self, auth_identifier: str) -> List[ConversationInfo]:
        """List the conversations of a user whose records are present or cleared.

        Rows are queried ordered by timestamp; for each conversation id the
        entry with the smallest timestamp seen is the one reported.

        Args:
            auth_identifier: Identifier of the user whose conversations are listed.

        Returns:
            One ``ConversationInfo`` per distinct conversation id.
        """
        user_col = self.col("user")
        conversation_id_col = self.col("conversation_id")
        name_col = self.col("conversation_name")
        timestamp_col = self.col("timestamp")
        state_col = self.col("state")

        conditions = [
            WhereCondition(column=user_col, operator=Operator.EQ, value=auth_identifier),
            WhereCondition(
                column=state_col,
                operator=Operator.OR,
                value=[RecordState.PRESENT.value, RecordState.CLEARED.value],
            ),
        ]
        rows: pd.DataFrame = self.select_columns_from_dataset(
            column_names=[user_col, conversation_id_col, name_col, timestamp_col, state_col],
            distinct=True,
            eq_cond=conditions,
            format_="dataframe",
            order_by=timestamp_col,
        )

        by_id: Dict[str, ConversationInfo] = {}
        for _, row in rows.iterrows():
            conversation_id = row[conversation_id_col]
            known = by_id.get(conversation_id)
            # Skip rows that are not older than the entry already recorded.
            if known is not None and not (row[timestamp_col] < known["timestamp"]):
                continue
            by_id[conversation_id] = ConversationInfo(
                id=conversation_id,
                name=row[name_col],
                timestamp=row[timestamp_col],
            )
        return list(by_id.values())

    @log_query_time
    def get_conversation(self, auth_identifier: str, conversation_id: str, only_present: bool = True):
        """Fetch one conversation of a user and convert it to a ``Conversation``.

        Args:
            auth_identifier: Identifier of the user owning the conversation.
            conversation_id: Identifier of the conversation to load.
            only_present: When true (default), only records whose state is
                ``present`` are queried and returned. When false, records in
                any state are queried and kept in the converted result.

        Returns:
            Whatever ``convert_query_result_to_conversation`` produces:
            ``None`` when the query yields no rows, otherwise the conversation.
        """
        conditions = [
            WhereCondition(column=self.col("user"), value=auth_identifier, operator=Operator.EQ),
            WhereCondition(column=self.col("conversation_id"), value=conversation_id, operator=Operator.EQ),
        ]
        if only_present:
            conditions.append(
                WhereCondition(column=self.col("state"), value=RecordState.PRESENT.value, operator=Operator.EQ)
            )
        result: pd.DataFrame = self.select_columns_from_dataset(
            column_names=self.columns,
            eq_cond=conditions,
            format_="dataframe",
            order_by=self.col("message_id"),
        )
        # Forward the caller's choice. Hard-coding True here made the conversion
        # drop the non-present rows that had just been fetched on purpose.
        return self.convert_query_result_to_conversation(result, only_present=only_present)

    def sql_query_to_df_with_error_handling(self, sql_query):
        """Execute ``sql_query`` and return its result, logging instead of raising.

        Returns:
            The value returned by the executor, or ``None`` when execution raised.
        """
        try:
            outcome = self.executor.query_to_df(sql_query, post_queries=get_post_queries(self.dataset))
        except Exception as e:
            logger.exception(f"Error when executing SQL query: {e}", log_conv_id=True)
            return None
        return outcome

    def _execute_async_sql_tasks(self, sql_query):
        if self.thread_executor is None:
            self.sql_query_to_df_with_error_handling(sql_query)
        else:
            # Submit the task to the executor
            self.thread_executor.submit(self.sql_query_to_df_with_error_handling, sql_query)
            logger.debug("Task submitted to the executor", log_conv_id=True)

    @log_query_time
    def add_record(  # noqa: PLR0917 too many positional arguments
        self,
        record: QuestionData,
        auth_identifier: str,
        conversation_id: Optional[str],
        conversation_name: Optional[str] = CONVERSATION_DEFAULT_NAME,
        is_new_conversation: Optional[bool] = True,
        llm_id: Optional[str] = None,
    ):
        """Insert one conversation turn into the logging dataset.

        Args:
            record: The turn to store. ``query``, ``answer`` and ``sources`` are
                required; ``feedback``, ``llm_context`` and ``generated_media``
                are optional.
            auth_identifier: Identifier of the user owning the conversation.
            conversation_id: Target conversation. When ``None`` and the
                conversation is new, a UUID4 is generated.
            conversation_name: Name stored with the record.
            is_new_conversation: When truthy the record is stored with message
                id ``"0"``; otherwise ``str(record["id"])`` is used.
            llm_id: Passed to ``get_llm_friendly_name`` to fill ``llm_name``.

        Returns:
            A ``(record_id, ConversationInfo)`` tuple for the stored turn.
        """
        if is_new_conversation:
            logger.debug("New conversation to create", log_conv_id=True)
            if conversation_id is None:
                conversation_id = str(uuid.uuid4())
                logger.debug(f"The conversation_id was not set: the conversation will be saved with the ID '{conversation_id}'", log_conv_id=True)
            record_id = str(0)
        else:
            record_id = str(record["id"])
        logger.debug(f"Adding the conversation turn n°'{record_id}' to the conversation_id '{conversation_id}'", log_conv_id=True)

        # Resolve the feedback once; a missing or empty feedback is stored as
        # three empty strings.
        feedback = record.get("feedback")
        if feedback:
            feedback_value = feedback["value"]
            feedback_choice = ";".join(feedback["choice"])
            feedback_message = feedback["message"]
        else:
            feedback_value = feedback_choice = feedback_message = ""

        llm_context = record.get("llm_context", {})
        generated_media = record.get("generated_media", {})
        timestamp = str(time.time())
        # Values must follow the order of the module-level COLUMNS list.
        record_value = [
            conversation_id,
            conversation_name,
            get_llm_friendly_name(llm_id),
            auth_identifier,
            record_id,
            record["query"],
            record["answer"],
            None,  # json.dumps({"filters": {}}, ensure_ascii=False),
            json.dumps({"sources": record["sources"]}, ensure_ascii=False),
            feedback_value,
            feedback_choice,
            feedback_message,
            timestamp,
            RecordState.PRESENT.value,
            json.dumps(llm_context, ensure_ascii=False),
            # Same escaping policy as the other JSON columns.
            json.dumps(generated_media, ensure_ascii=False),
        ]
        self.insert_record(record_value, thread_pool_executor=self.thread_executor)
        return record_id, ConversationInfo(id=conversation_id, name=conversation_name, timestamp=timestamp)  # type: ignore

    @log_query_time
    def update_answer(
        self,
        auth_identifier: str,
        conversation_id: str,
        message_id: str,
        answer: str,
    ) -> None:
        """Overwrite the stored answer of one message.

        The target row is matched on conversation id, user and message id.

        Raises:
            BadRequest: If the UPDATE statement fails.
        """
        builder = UpdateQueryBuilder(self.dataset)
        builder = builder.add_set_cols([(self.col("answer"), answer)])
        target_conditions = [
            WhereCondition(column=self.col("conversation_id"), value=conversation_id, operator=Operator.EQ),
            WhereCondition(column=self.col("user"), value=auth_identifier, operator=Operator.EQ),
            WhereCondition(column=self.col("message_id"), value=str(message_id), operator=Operator.EQ),
        ]
        update_query = builder.add_conds(target_conditions).build()
        try:
            self.executor.query_to_df(update_query, post_queries=get_post_queries(self.dataset))
        except Exception as err:
            logger.exception(err, log_conv_id=True)
            raise BadRequest(f"Error when executing SQL query: {err}")

    @log_query_time
    def get_conversation_history(self, auth_identifier: str, conversation_id: str):
        conversation = self.get_conversation(auth_identifier=auth_identifier, conversation_id=conversation_id)
        agents_files_uploads: Optional[Dict[str, Dict[str, MediaSummary]]] = None
        if conversation:
            messages = []
            if conversation.get("media_summaries"):
                # Handle case of initial media summarization
                # TODO: consider removing media_summaries
                messages.append(
                    LlmHistory(
                        input="",
                        output=json.dumps({"uploaded_docs": conversation["media_summaries"]}),
                    )
                )
            for item in conversation["data"]:
                if item.get("answer"):
                    if item.get("uploaded_docs"):
                        # Handle the case when the user uploads docs with a query
                        messages.append(
                            LlmHistory(
                                input=json.dumps({"uploaded_docs": item["uploaded_docs"], "query": item["query"]}),
                                output=item["answer"],
                            )
                        )
                    else:
                        messages.append(LlmHistory(input=item["query"], output=item["answer"]))
                elif item.get("generated_media") and item.get("generated_media").get("images"):
                    messages.append(
                        LlmHistory(
                            input=item["query"], output=json.dumps({"generated_media_by_ai": item["generated_media"]})
                        )
                    )
                elif item.get("uploaded_docs") and not item.get("answer"):
                    # Handle the case when the user uploads docs for summary during an existing conversation
                    messages.append(
                        LlmHistory(input=item["query"], output=json.dumps({"uploaded_docs": item["uploaded_docs"]}))
                    )
                if item.get("agents_files_uploads"):
                    # Extract the uploaded files summary map from the agents
                    previously_uploaded_files = item.get("agents_files_uploads", {})
                    if not agents_files_uploads:
                        agents_files_uploads = {}
                    for file_name, file_summary in previously_uploaded_files.items():
                        if file_name not in agents_files_uploads:
                            agents_files_uploads[file_name] = {}
                        # Append uploaded files summary to the existing dictionary
                        agents_files_uploads[file_name].update(file_summary)

            return messages, conversation["name"], agents_files_uploads
        return [], None, None

    @log_query_time
    def _clear_conversation_history_permanent(self, auth_identifier: str, conversation_id: str) -> None:
        """Erase a conversation's content while keeping a placeholder record.

        Every record of the conversation except message ``"0"`` is deleted.
        Message ``"0"`` keeps its identifying columns, gets its other columns
        blanked and its state set to ``cleared``.

        Args:
            auth_identifier: Identifier of the user owning the conversation.
            conversation_id: Identifier of the conversation to clear.

        Raises:
            BadRequest: If either SQL statement fails.
        """

        def conversation_conds(message_operator):
            # Fresh condition objects for each query so the builders share nothing.
            return [
                WhereCondition(
                    column=self.col("conversation_id"),
                    value=conversation_id,
                    operator=Operator.EQ,
                ),
                WhereCondition(column=self.col("user"), value=auth_identifier, operator=Operator.EQ),
                WhereCondition(column=self.col("message_id"), value=str(0), operator=message_operator),
            ]

        delete_query = DeleteQueryBuilder(self.dataset).add_conds(conversation_conds(Operator.NE)).build()

        # Columns left out of the blanking pass. `state` belongs here because it
        # is assigned explicitly below; blanking it as well would put the same
        # column twice in the SET clause.
        cols_to_keep = [
            self.col("conversation_id"),
            self.col("conversation_name"),
            self.col("user"),
            self.col("timestamp"),
            self.col("message_id"),
            self.col("state"),
        ]

        # JSON columns are reset to an empty payload, everything else to "".
        # NOTE(review): these keys are raw column names, as in the previous
        # implementation; confirm they match the entries of `self.columns`.
        blank_values = {
            "filters": json.dumps({"filters": None}, ensure_ascii=False),
            "sources": json.dumps({"sources": []}, ensure_ascii=False),
        }
        set_cols = [
            (column, blank_values.get(column, "")) for column in self.columns if column not in cols_to_keep
        ]
        set_cols.append((self.col("state"), RecordState.CLEARED.value))

        set_empty_record_query = (
            UpdateQueryBuilder(self.dataset)
            .add_conds(conversation_conds(Operator.EQ))
            .add_set_cols(set_cols)
            .build()
        )
        try:
            self.executor.query_to_df(delete_query, post_queries=get_post_queries(self.dataset))
            self.executor.query_to_df(set_empty_record_query, post_queries=get_post_queries(self.dataset))
        except Exception as err:
            logger.exception(err)
            raise BadRequest(f"Error when executing SQL query: {err}") from err

    @log_query_time
    def _clear_conversation_history_non_permanent(self, auth_identifier: str, conversation_id: str) -> None:
        """Flag every record of a user's conversation as cleared, deleting nothing.

        Raises:
            BadRequest: If the UPDATE statement fails.
        """
        builder = UpdateQueryBuilder(self.dataset)
        builder = builder.add_set_cols([(self.col("state"), RecordState.CLEARED.value)])
        owner_conditions = [
            WhereCondition(column=self.col("conversation_id"), value=conversation_id, operator=Operator.EQ),
            WhereCondition(column=self.col("user"), value=auth_identifier, operator=Operator.EQ),
        ]
        update_query = builder.add_conds(owner_conditions).build()

        try:
            self.executor.query_to_df(update_query, post_queries=get_post_queries(self.dataset))
        except Exception as err:
            logger.exception(err)
            raise BadRequest(f"Error when executing SQL query: {err}")

    @log_query_time
    def clear_conversation_history(self, auth_identifier: str, conversation_id: str) -> None:
        # TODO: This one is tricky as we need a separate conversation dataset for this
        # Basic solution would be to delete everything except for the first message
        # and then update the message with blank answer and query
        if self.permanent_delete:
            self._clear_conversation_history_permanent(auth_identifier=auth_identifier, conversation_id=conversation_id)
        else:
            self._clear_conversation_history_non_permanent(
                auth_identifier=auth_identifier, conversation_id=conversation_id
            )

    @log_query_time
    def _get_conversation_media_paths(self, auth_identifier: str, conversation_id: Optional[str] = None) -> List[str]:
        """Collect the media file paths referenced by a user's records.

        Only records in the present or cleared state are inspected, optionally
        restricted to one conversation.

        Returns:
            Paths found in the ``llm_context`` and ``generated_media`` columns,
            or an empty list when no ``upload_folder`` is configured or the
            query returns nothing usable.
        """
        if dataiku_api.webapp_config.get("upload_folder") is None:
            return []

        conditions = [
            WhereCondition(column=self.col("user"), operator=Operator.EQ, value=auth_identifier),
            WhereCondition(
                column=self.col("state"),
                operator=Operator.OR,
                value=[RecordState.PRESENT.value, RecordState.CLEARED.value],
            ),
        ]
        if conversation_id:
            conversation_filter = WhereCondition(
                column=self.col("conversation_id"), operator=Operator.EQ, value=conversation_id
            )
            conditions.append(conversation_filter)

        media_df = self.select_columns_from_dataset(
            column_names=[self.col("llm_context"), self.col("generated_media")],
            eq_cond=conditions,
        )
        # A generator result cannot be inspected as a DataFrame; treat as empty.
        if isinstance(media_df, Generator) or media_df.empty:
            return []

        context_cells = media_df["llm_context"].dropna()
        media_cells = media_df["generated_media"].dropna()
        uploaded_paths = ConversationSQL.get_paths_from_history(
            series=context_cells, data_key="media_qa_context", additional_file_name="metadata_path"
        )
        generated_paths = ConversationSQL.get_paths_from_history(
            series=media_cells, data_key="images", additional_file_name="referred_file_path"
        )
        return uploaded_paths + generated_paths

    @log_query_time
    def delete_user_conversation(self, auth_identifier: str, conversation_id: str) -> None:
        """Delete one conversation of a user.

        With ``permanent_delete`` the referenced media files are removed and the
        rows deleted; otherwise the rows are only flagged as deleted.

        Raises:
            BadRequest: If the SQL statement fails.
        """
        if self.permanent_delete:
            media_to_delete = self._get_conversation_media_paths(auth_identifier, conversation_id)
            if media_to_delete:
                logger.debug(f"Deleting media files: {media_to_delete}")
                ConversationSQL.delete_files(media_to_delete)
            builder = DeleteQueryBuilder(self.dataset)
        else:
            builder = UpdateQueryBuilder(self.dataset).add_set_cols(
                [(self.col("state"), RecordState.DELETED.value)]
            )

        # Both strategies target the same rows: this user's conversation.
        owner_conditions = [
            WhereCondition(column=self.col("conversation_id"), value=conversation_id, operator=Operator.EQ),
            WhereCondition(column=self.col("user"), value=auth_identifier, operator=Operator.EQ),
        ]
        query = builder.add_conds(owner_conditions).build()
        try:
            self.executor.query_to_df(query, post_queries=get_post_queries(self.dataset))
        except Exception as err:
            logger.exception(err)
            raise BadRequest(f"Error when executing SQL query: {err}")

    @log_query_time
    def delete_all_user_conversations(self, auth_identifier: str) -> None:
        """Delete every conversation of a user.

        With ``permanent_delete`` the referenced media files are removed and the
        rows deleted; otherwise the rows are only flagged as deleted.

        Raises:
            BadRequest: If the SQL statement fails.
        """
        if self.permanent_delete:
            media_to_delete = self._get_conversation_media_paths(auth_identifier)
            if media_to_delete:
                logger.debug(f"Deleting media files: {media_to_delete}")
                ConversationSQL.delete_files(media_to_delete)
            builder = DeleteQueryBuilder(self.dataset)
        else:
            builder = UpdateQueryBuilder(self.dataset).add_set_cols(
                [(self.col("state"), RecordState.DELETED.value)]
            )

        user_condition = WhereCondition(column=self.col("user"), value=auth_identifier, operator=Operator.EQ)
        query = builder.add_conds([user_condition]).build()
        try:
            self.executor.query_to_df(query, post_queries=get_post_queries(self.dataset))
        except Exception as err:
            logger.exception(err)
            raise BadRequest(f"Error when executing SQL query: {err}")

    @log_query_time
    def update_feedback(
        self,
        auth_identifier: str,
        conversation_id: str,
        message_id: str,
        feedback: Feedback,
    ) -> None:
        """Store the feedback given on one message.

        The feedback choices are persisted as a single ``;``-joined string.

        Raises:
            BadRequest: If the UPDATE statement fails.
        """
        builder = UpdateQueryBuilder(self.dataset)
        feedback_assignments = [
            (self.col("feedback_value"), feedback["value"]),
            (self.col("feedback_message"), feedback["message"]),
            (self.col("feedback_choice"), ";".join(feedback["choice"])),
        ]
        builder = builder.add_set_cols(feedback_assignments)
        message_conditions = [
            WhereCondition(column=self.col("conversation_id"), value=conversation_id, operator=Operator.EQ),
            WhereCondition(column=self.col("user"), value=auth_identifier, operator=Operator.EQ),
            WhereCondition(column=self.col("message_id"), value=str(message_id), operator=Operator.EQ),
        ]
        update_query = builder.add_conds(message_conditions).build()
        try:
            self.executor.query_to_df(update_query, post_queries=get_post_queries(self.dataset))
        except Exception as err:
            logger.exception(err)
            raise BadRequest(f"Error when executing SQL query: {err}")

    def __validate_columns(self, column_names: List[str]):
        return all(name in self.columns for name in column_names)

    def convert_query_result_to_conversation(self, result: pd.DataFrame, only_present: bool = True):
        if result.empty:
            return None
        else:
            result = result.sort_values(by=self.col("message_id"), ascending=True)
            id = result[self.col("conversation_id")].iloc[0]
            name = result[self.col("conversation_name")].iloc[0]
            timestamp = result[self.col("timestamp")].iloc[0]
            auth_identifier = result[self.col("user")].iloc[0]
            question_data: List[QuestionData] = []
            is_media_qa = False
            media_summaries = None
            for index, row in result.iterrows():
                if only_present and row[self.col("state")] != RecordState.PRESENT.value:
                    continue
                query = row[self.col("question")]
                answer = row[self.col("answer")]
                feedback_value = row[self.col("feedback_value")]
                feedback_choice = row[self.col("feedback_choice")]
                feedback_message = row[self.col("feedback_choice")] # TODO: investigate
                message_timestamp = row[self.col("timestamp")]
                record_id = row[self.col("message_id")]
                feedback = None
                generated_media = ConversationSQL.load_generated_media(row[self.col("generated_media")])
                # TODO: should all happen in load_sources
                sources = ConversationSQL.load_sources(row[self.col("sources")])
                filters = ConversationSQL.load_filters(row[self.col("filters")])
                llm_context = ConversationSQL.load_llm_context(row[self.col("llm_context")])
                agents_selection = llm_context.get("agents_selection") if llm_context else None

                uploaded_docs = None
                if llm_context is not None:
                    agents_files_uploads = llm_context.get("agents_files_uploads")
                    media_qa_context = llm_context.get("media_qa_context") if llm_context else None
                    if media_qa_context is not None and index == 0:
                        is_media_qa = True
                        media_summaries = ConversationSQL.get_uploaded_docs(media_qa_context)
                    if uploaded_docs := llm_context.get("uploaded_docs", []):
                        uploaded_docs = ConversationSQL.get_uploaded_docs(uploaded_docs)
                if feedback_value:
                    feedback = Feedback(
                        value=feedback_value,
                        message=feedback_message if feedback_message else "",
                        choice=ConversationSQL.format_feedback_choice(feedback_choice),
                    )
                question_data.append(
                    PortalQuestionData(
                        id=record_id,
                        query=query,
                        answer=answer,
                        filters=filters,
                        sources=sources,
                        feedback=feedback,
                        agents_selection=agents_selection,
                        agents_files_uploads=agents_files_uploads,
                        timestamp=message_timestamp,
                        generated_media=generated_media,
                        uploaded_docs=uploaded_docs,
                    )
                )
            return Conversation(
                id=id,
                name=name,
                timestamp=timestamp,
                auth_identifier=auth_identifier,
                data=question_data,
                media_summaries=media_summaries,
                conversation_type=ConversationType.MEDIA_QA if is_media_qa else ConversationType.GENERAL,
            )

    @staticmethod
    def format_feedback_choice(choice: Optional[str]):
        if choice is None or choice == "":
            return []
        return [value.strip() for value in choice.split(";")]

    @staticmethod
    def load_source_items(items: List[ToolCall]):
        result = []
        for item in items:
            tool_items = item.get("items", [])
            for tool_item in tool_items:
                if tool_item.get("images"):
                    tool_item["images"] = load_images_data(tool_item.get("images") or [])
            result.append(ToolCall(toolCallDescription=item.get("toolCallDescription", ""), items=tool_items))
        return result

    @staticmethod
    def load_sources(sources_str: str):
        if not sources_str:
            return []
        sources_json = json.loads(sources_str)
        sources = sources_json.get("sources", [])
        result = [
            PortalSource(
                name=source.get("name", ""),
                id=source.get("id", ""),
                type=source.get("type", ""),
                items=ConversationSQL.load_source_items(source.get("items", [])),
                answer=source.get("answer", ""),
            )
            for source in sources
        ]
        return result

    @staticmethod
    def load_filters(filters_str: str):
        if not filters_str:
            return None
        filters_json = json.loads(filters_str)
        filters = filters_json.get("filters", None)
        return filters

    @staticmethod
    def load_llm_context(llm_context: Optional[str]):
        return json.loads(llm_context) if llm_context else None

    @staticmethod
    def load_generated_media(generated_media: Optional[str]):
        return json.loads(generated_media) if generated_media else None

    @staticmethod
    def get_uploaded_docs(uploaded_media: List[MediaSummary]) -> Union[List[MediaSummary], None]:
        if not uploaded_media:
            return None
        folder = dataiku_api.folder_handle
        uploaded_doc_context = []
        try:
            # List all files in the managed folder
            file_list = folder.list_paths_in_partition()
            for media in uploaded_media:
                metadata_path: Optional[str] = media.get("metadata_path")
                if not metadata_path:
                    raise ValueError("metadata_path is missing")
                # Handle the case the file was deleted
                if "/" + metadata_path in file_list:
                    extract_summary = folder.read_json(metadata_path)
                    new_summary: MediaSummary = {**media, **extract_summary}
                    uploaded_doc_context.append(new_summary)
                else:
                    logger.info(f"The file {metadata_path} does not exist in the folder.")
                    uploaded_doc_context.append({**media, "is_deleted": True})
            return uploaded_doc_context
        except Exception as e:
            logger.exception(f"Error occurred while retrieving uploaded media: {e}")
            return None

    @staticmethod
    def get_paths_from_history(series: pd.Series, data_key: str, additional_file_name: str) -> List[str]:
        def get_paths(row):
            if not row:
                return []
            media_qa_context = json.loads(row).get(data_key, [])
            return itertools.chain(
                (media.get(additional_file_name) for media in media_qa_context if media.get(additional_file_name)),
                (media.get("file_path") for media in media_qa_context if media.get("file_path")),
            )

        series = series.map(get_paths)  # type: ignore
        unique_values = list(set(item for sublist in series for item in sublist if item))
        return unique_values

    @staticmethod
    def delete_files(file_paths: List[str]) -> None:
        """Delete each path from ``dataiku_api.folder_handle``, best effort.

        A failure on one path is logged and the remaining paths are still tried.
        """
        managed_folder = dataiku_api.folder_handle
        for file_path in file_paths:
            try:
                managed_folder.delete_path(file_path)
            except Exception as e:
                logger.exception(f"Error occurred while deleting file: {e}")


# Module-level instance created at import time. Importing this module therefore
# runs `ConversationSQL.__init__`, which reads the webapp configuration.
conversation_sql_manager = ConversationSQL()
