import json
import uuid
from typing import Any, Dict, List, Optional

import pandas as pd
from answers.backend.db.messages import messages_sql_manager
from answers.backend.db.sql.dataset_schemas import ANSWERS_DATASETS_METADATA, CONVERSATION_DATASET_CONF_ID
from common.backend.db.sql.queries import (
    DeleteQueryBuilder,
    UpdateQueryBuilder,
    WhereCondition,
    get_post_queries,
)
from common.backend.db.sql.tables_managers import GenericConversationSQL
from common.backend.models.base import (
    APIApplicationContext,
    APIConversationMetadata,
    APIConversationsResponse,
    APIConversationTitle,
    APISingleConversationResponse,
    ConversationInfo,
    ConversationInsertInfo,
    RecordState,
)
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.date_utils import get_string_date_interpreted_as_string
from common.backend.utils.sql_timing import log_execution_time
from common.llm_assist.logging import logger
from dataiku.sql.expression import Operator
from werkzeug.exceptions import BadRequest


# TODO: refactor this function; it duplicates the same helper in messages.py.
def stringify(data) -> Optional[str]:
    """Serialize *data* to a JSON string, keeping non-ASCII characters as-is.

    Falsy inputs (``None``, empty containers, ``0``, ``""``) produce ``None``
    rather than a JSON document.
    """
    if not data:
        return None
    return json.dumps(data, ensure_ascii=False)


class ConversationsSQL(GenericConversationSQL):
    """Read/write access to the Answers API conversations dataset.

    Conversation rows are filtered on ``platform``, ``user`` and
    ``conversation_id``; the messages attached to a conversation are handled
    through ``messages_sql_manager``.
    """

    def __init__(self):
        config = dataiku_api.webapp_config
        # Leave the manager unconfigured (the parent initializer is skipped) when
        # the Answers API is disabled or no conversation dataset is configured.
        if not config.get("enable_answers_api", False) or not config.get(CONVERSATION_DATASET_CONF_ID):
            return
        super().__init__(config=config,
                         columns=ANSWERS_DATASETS_METADATA[CONVERSATION_DATASET_CONF_ID]["columns"],
                         dataset_conf_id=CONVERSATION_DATASET_CONF_ID,
                         default_project_key=dataiku_api.default_project_key)

        # True (default): deletions remove rows; False: rows are flagged RecordState.DELETED.
        self.permanent_delete = self.config.get("permanent_delete", True)

    def _eq_condition(self, column_name: str, value: Any) -> WhereCondition:
        """Return a ``column = value`` condition on a column of this dataset."""
        return WhereCondition(column=self.col(column_name), value=value, operator=Operator.EQ)

    def _run_write_query(self, query: Any, success_message: str) -> None:
        """Execute an UPDATE/DELETE query on the conversations dataset.

        Args:
            query: Query built by ``UpdateQueryBuilder`` / ``DeleteQueryBuilder``.
            success_message: Logged at INFO level once the query has succeeded.

        Raises:
            BadRequest: If executing the query fails; the original error is chained.
        """
        try:
            self.executor.query_to_df(query, post_queries=get_post_queries(self.dataset))
        except Exception as err:
            logger.error(err)
            raise BadRequest(f"Error when executing SQL query: {err}") from err
        logger.info(success_message)

    @log_execution_time
    def add_conversation(
        self,
        user: str,
        conversation_info: ConversationInsertInfo,
    ):
        """Insert a new conversation row for *user* and return its info.

        A fresh UUID4 is generated as the conversation id, and the same
        timestamp is used for creation and last update. The row is stored with
        state ``RecordState.PRESENT``.

        Args:
            user: Owner of the conversation.
            conversation_info: Must contain ``"name"``; ``"platform"`` and
                ``"metadata"`` are optional (metadata is stored as JSON).

        Returns:
            ConversationInfo: id, name and timestamps of the new conversation.
        """
        logger.debug("Adding an API conversation...")
        conversation_id = str(uuid.uuid4())
        timestamp = get_string_date_interpreted_as_string(sql_dialect=self.dialect)
        conversation_name = conversation_info["name"]
        # Value order must match the dataset column order.
        record_value = [
            conversation_id,
            user,
            conversation_info.get("platform"),
            timestamp,
            conversation_name,
            timestamp,
            RecordState.PRESENT.value,
            stringify(conversation_info.get("metadata")),
        ]
        self.insert_record(record_value)
        return ConversationInfo(
            id=conversation_id, name=conversation_name, timestamp=timestamp, created_at=timestamp, updated_at=timestamp #type: ignore
        )

    @log_execution_time
    def update_conversation_metadata(self, user: str, conversation_id: str,
                                     column_updates: Dict[str, Any], platform: Optional[str]):
        """Update columns of one of *user*'s conversations.

        Args:
            user: Owner of the conversation.
            conversation_id: Conversation to update.
            column_updates: Mapping of column name to new value; names are
                lower-cased before being resolved with ``self.col``.
            platform: When not ``None``, the update is also restricted to this platform.

        Raises:
            BadRequest: If *column_updates* is empty or the query fails.
        """
        logger.debug("Updating API conversation metadata...")
        if not column_updates:
            # An UPDATE without SET columns is not valid SQL: fail fast with a clear message.
            raise BadRequest("No column to update was provided")

        conditions = [
            self._eq_condition("user", user),
            self._eq_condition("conversation_id", conversation_id),
        ]
        if platform is not None:
            conditions.append(self._eq_condition("platform", platform))

        set_options = [
            (self.col(column_name.lower()), column_value)
            for column_name, column_value in column_updates.items()
        ]
        update_query = (
            UpdateQueryBuilder(self.dataset)
            .add_set_cols(set_options)
            .add_conds(conditions)
            .build()
        )
        self._run_write_query(
            update_query,
            f"Conversation successfully updated (conditions={conditions}, column_updates={column_updates}) ",
        )

    @log_execution_time
    def delete_conversation_permanently(self, conditions: List[WhereCondition]):
        """Delete the conversation rows matching *conditions*.

        Raises:
            BadRequest: If the query fails.
        """
        delete_query = DeleteQueryBuilder(self.dataset).add_conds(conds=conditions).build()
        self._run_write_query(delete_query, f"Conversation successfully deleted (conditions={conditions})")

    @log_execution_time
    def update_conversation_as_deleted(self, conditions: List[WhereCondition]):
        """Set ``state`` to ``RecordState.DELETED`` on rows matching *conditions*.

        Raises:
            BadRequest: If the query fails.
        """
        update_query = (
            UpdateQueryBuilder(self.dataset)
            .add_set_cols([(self.col("state"), RecordState.DELETED.value)])
            .add_conds(conditions)
            .build()
        )
        self._run_write_query(update_query, f"Conversation successfully flagged as deleted (conditions={conditions})")

    @log_execution_time
    def delete_conversation(self, platform: str, user: str, conversation_id: str):
        """Delete one conversation and its messages.

        When ``self.permanent_delete`` is true, the messages and then the
        conversation row are removed. Otherwise both are flagged as deleted.
        """
        conditions = [
            self._eq_condition("platform", platform),
            self._eq_condition("user", user),
            self._eq_condition("conversation_id", conversation_id),
        ]
        if self.permanent_delete:
            logger.info(f"Deleting permanently the user '{user}' conversation '{conversation_id}' ...")
            # Messages (and their files) go first, then the conversation itself.
            messages_sql_manager.delete_conversation_messages(platform, user, conversation_id)
            self.delete_conversation_permanently(conditions)
        else:
            logger.info(f"Flagging the user '{user}' conversation '{conversation_id}' as deleted ...")
            messages_sql_manager.update_messages_as_deleted(conditions)
            self.update_conversation_as_deleted(conditions)

    @log_execution_time
    def get_conversation_messages(self, platform: str, user: str, conversation_id: str, only_present: bool=True):
        """Return a conversation with its messages as an ``APISingleConversationResponse``.

        Args:
            platform: Application identifier the conversation belongs to.
            user: Owner of the conversation.
            conversation_id: Conversation to fetch.
            only_present: Forwarded to ``messages_sql_manager.get_all_conversation_messages``.
                The conversation row itself is fetched regardless of its state.

        Returns:
            APISingleConversationResponse: the conversation and its messages. Title and
            date fields default to ``""`` when no conversation row matches.
        """
        def extract_conversation_metadata(query_result: pd.DataFrame) -> Dict[str, Any]:
            """Return the first row of *query_result* as a dict ({} when empty)."""
            if query_result.empty:
                return {}
            # Positional access does not depend on the result's index labels.
            row = query_result.iloc[0].to_dict()
            # Replace missing metadata with {}. The former in-place
            # `Series.fillna({})` never filled anything, because a dict value is
            # treated as an index->value mapping.
            metadata = row.get("metadata")
            if metadata is None or (pd.api.types.is_scalar(metadata) and pd.isna(metadata)):
                row["metadata"] = {}
            return row

        conditions = [
            self._eq_condition("platform", platform),
            self._eq_condition("user", user),
            self._eq_condition("conversation_id", conversation_id),
        ]
        conversation_metadata_df: pd.DataFrame = self.select_columns_from_dataset(
            column_names=self.columns, eq_cond=conditions, format_="dataframe"
        )
        conversation_metadata = extract_conversation_metadata(conversation_metadata_df)
        messages = messages_sql_manager.get_all_conversation_messages(platform, user, conversation_id, only_present)
        created_at = conversation_metadata.get("created_at", "")
        conversation_title = APIConversationTitle(
            original=conversation_metadata.get("conversation_name", ""),
            edited="",  # TODO: Adapt when the title edition is implemented
            createdAt=created_at,
        )
        return APISingleConversationResponse(
            user=user,
            id=conversation_id,
            context=APIApplicationContext(applicationId=platform),
            title=conversation_title,
            messages=messages,
            createdAt=created_at,
            lastMessageAt=conversation_metadata.get("updated_at", ""),
            state=conversation_metadata.get("state", ""),
        )

    @log_execution_time
    def get_all_user_conversations(self, platform: str, user: str, only_present: bool=True) -> APIConversationsResponse:
        """List *user*'s conversations on *platform*, ordered by ``updated_at``.

        Args:
            platform: Application identifier to filter on.
            user: Owner of the conversations.
            only_present: When true, only rows with state ``RecordState.PRESENT`` are returned.

        Returns:
            APIConversationsResponse: one ``APIConversationMetadata`` per conversation.
        """
        conditions = [
            self._eq_condition("platform", platform),
            self._eq_condition("user", user),
        ]
        if only_present:
            conditions.append(self._eq_condition("state", RecordState.PRESENT.value))

        result: pd.DataFrame = self.select_columns_from_dataset(
            column_names=self.columns,
            distinct=True,
            eq_cond=conditions,
            format_="dataframe",
            order_by=self.col("updated_at"),
        )
        conversations_info = []
        for _, row in result.iterrows():
            conversation_title = APIConversationTitle(
                original=row["conversation_name"],
                edited="",  # TODO: Adapt when the title edition is implemented
                createdAt=row["created_at"],
            )
            conversations_info.append(
                APIConversationMetadata(
                    id=row["conversation_id"],
                    title=conversation_title,
                    createdAt=row["created_at"],
                    lastMessageAt=row["updated_at"],
                    state=row["state"],
                )
            )

        return APIConversationsResponse(
            user=user,
            context=APIApplicationContext(applicationId=platform),
            conversations=conversations_info,
        )

    @log_execution_time
    def delete_all_user_conversations(self, platform, user):
        """Delete every conversation of *user* on *platform*, including already-flagged ones.

        Each conversation goes through ``delete_conversation``, so the configured
        permanent/soft deletion mode applies.
        """
        logger.info(f"Deleting all conversations for user '{user}'")
        all_user_conversations = self.get_all_user_conversations(platform, user, only_present=False)
        for conversation_info in all_user_conversations["conversations"]:
            self.delete_conversation(platform, user, conversation_info["id"])

# Module-level singleton. Constructing it reads `dataiku_api.webapp_config` at
# import time. When the Answers API is disabled or no conversation dataset is
# configured, `ConversationsSQL.__init__` returns early and the instance is left
# without its dataset configuration.
conversations_sql_manager = ConversationsSQL()
