from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import dataiku
import pandas as pd
from common.backend.constants import INSERT_WITH_APPEND_MODE_DIALECTS, PROFILE_LAST_UPDATED, SUPPORTED_DIALECT_VALUES
from common.backend.db.sql.queries import (
    InsertQueryBuilder,
    WhereCondition,
    _get_where_and_cond,
    get_post_queries,
)
from common.backend.models.base import (
    ConversationInfo,
    ConversationInsertInfo,
    Feedback,
    MessageInsertInfo,
    QuestionData,
)
from common.backend.utils.sql_timing import log_execution_time
from common.llm_assist.logging import logger
from dataiku import Dataset, SQLExecutor2
from dataiku.core.sql import QueryReader
from dataiku.sql import Column, Dialects, SelectQuery, toSQL
from werkzeug.exceptions import BadRequest


class GenericSQLManager(ABC):
    """Base class for managers that persist application data in a Dataiku SQL dataset.

    On construction it resolves the dataset named by ``config[dataset_conf_id]`` in the
    project ``default_project_key``. It then:

    - verifies that the dataset exists;
    - initializes the schema when the dataset has none yet;
    - detects whether the table's columns are upper-cased, so that column names can be
      adapted through :meth:`col`.

    Subclasses build their domain-specific operations on top of :meth:`execute`,
    :meth:`select_columns_from_dataset` and :meth:`insert_record`.
    """

    def __init__(self, config: Dict[str, Any], columns: List[str], dataset_conf_id: str, default_project_key: str):
        """
        Args:
            config: Configuration mapping. ``config[dataset_conf_id]`` holds the dataset name.
            columns: Expected column names of the dataset. They are upper-cased after
                initialization if the underlying table uses upper-case column names.
            dataset_conf_id: Key in ``config`` under which the dataset name is stored.
            default_project_key: Key of the Dataiku project that owns the dataset.

        Raises:
            ValueError: If no dataset name is configured or the dataset does not exist in the project.
        """
        self._config = config
        self._columns = columns
        # A missing key yields str(None) == "None"; __verify treats that as "not configured".
        self._dataset_name = str(config.get(dataset_conf_id))
        self._default_project_key = default_project_key
        self._dataset = Dataset(project_key=self.default_project_key, name=self.dataset_name)
        self._dialect = str(self._dataset.get_config().get("type"))

        # 'appendMode' unlocks the capability to append records when using a dataset writer.
        # It is only enabled for dialects that must insert through the writer instead of SQL INSERT.
        self.insert_with_append_mode = self._dialect in INSERT_WITH_APPEND_MODE_DIALECTS
        self._dataset.spec_item["appendMode"] = self.insert_with_append_mode
        if self.insert_with_append_mode:
            logger.info(
                f"Dialect '{self._dialect}' detected: data insert on the dataset '{dataset_conf_id}' will be done in 'appendMode'"
            )

        self._is_upper = False
        self.__verify(dataset_conf_id)
        self._executor = SQLExecutor2(dataset=self.dataset)
        self.__init_dataset()
        self.__get_column_names()

    @property
    def config(self) -> Dict[str, Any]:
        return self._config

    @property
    def columns(self) -> List[str]:
        return self._columns

    @columns.setter
    def columns(self, value) -> None:
        self._columns = value

    @property
    def dataset_name(self) -> Optional[str]:
        return self._dataset_name

    @property
    def default_project_key(self) -> str:
        return self._default_project_key

    @property
    def is_upper(self) -> bool:
        return self._is_upper

    @is_upper.setter
    def is_upper(self, value) -> None:
        self._is_upper = value

    @property
    def dataset(self) -> Dataset:
        return self._dataset

    @property
    def dialect(self) -> str:
        return self._dialect

    @property
    def executor(self) -> SQLExecutor2:
        return self._executor

    def __verify(self, dataset_conf_id) -> None:
        """Validate the configured dataset: name present, dataset exists, dialect known.

        Raises:
            ValueError: If the dataset name is missing or the dataset does not exist.
        """
        if self.dataset_name == "None" or not self.dataset_name:
            msg = f"{dataset_conf_id} Database name should not be null"
            logger.error(msg)
            raise ValueError(msg)
        self.__check_db_exists()
        self.__check_supported_dialect()

    def __check_db_exists(self) -> bool:
        """Return True if the dataset is listed in the project; raise ValueError otherwise."""
        client = dataiku.api_client()
        project = client.get_project(self.default_project_key)
        datasets = [item.name for item in project.list_datasets()]
        logger.debug(f"Searching for {self.dataset_name} in this project datasets: {datasets}")
        if self._dataset_name in datasets:
            return True
        msg = f"{self._dataset_name} dataset does not exist"
        logger.error(msg)
        raise ValueError(msg)

    def __check_supported_dialect(self) -> None:
        """Log a warning (without failing) when the dataset type is not a supported dialect."""
        if self.dialect in SUPPORTED_DIALECT_VALUES:
            return
        logger.warning(f"Dataset Type {self.dialect} is not supported")

    def __init_dataset(self) -> None:
        """Create the dataset schema/table when reading the existing schema fails."""
        try:
            self.dataset.read_schema(raise_if_empty=True)
        except Exception:
            # Any failure here, including an empty schema (raise_if_empty=True), is treated
            # as "dataset not initialized yet".
            logger.info(f"Initializing the {self.dataset_name} dataset schema")
            if self.dialect == Dialects.ORACLE:
                from common.backend.db.sql.oracle import OracleParams, OracleSQLManager

                oracle_params = OracleParams(
                    dataset_name=str(self.dataset_name),
                    executor=self.executor,
                    dataset=self.dataset,
                    columns=self.columns,
                )
                oracle_manager = OracleSQLManager(oracle_params)
                oracle_manager.setup_oracle_table()
            else:
                df = self.__get_init_df()
                # Store the profile timestamp column as datetime instead of string.
                # The column is matched case-insensitively and then indexed by its actual
                # name, so that a case mismatch cannot raise KeyError.
                for column_name in df.columns:
                    if column_name.lower() == PROFILE_LAST_UPDATED.lower():
                        df[column_name] = pd.to_datetime(df[column_name])
                self.dataset.write_with_schema(df=df)

    # TODO: can this be removed from queries.py?
    def columns_in_uppercase(self) -> bool:
        """Return True if every column of the table (sampled with ``SELECT * ... LIMIT 1``) is upper-case.

        Returns False, after logging the exception, if the probe query fails.
        """
        try:
            q = SelectQuery().select_from(self.dataset).select(Column("*")).limit(1)
            query = toSQL(q, dataset=self.dataset)
            df = self.executor.query_to_df(query)
            return all(c.isupper() for c in df.columns)
        except Exception as e:
            logger.exception(f"Error when trying to determine case of table columns: {e}")
            return False

    def __get_column_names(self) -> None:
        """Switch to upper-case column handling if the table's columns are upper-case."""
        if self.columns_in_uppercase():
            logger.warning(f"The dataset '{self.dataset.name}' must be handled in uppercase.")
            self.is_upper = True
            self.columns = [col.upper() for col in self.columns]

    def __get_init_df(self) -> pd.DataFrame:
        """Return an empty string-typed DataFrame with the expected columns (used to write the schema)."""
        data: Dict[str, List] = {col: [] for col in self.columns}
        return pd.DataFrame(data=data, columns=self.columns, dtype=str)

    def col(self, col_name: str) -> str:
        """Return ``col_name`` adapted to the table's column case."""
        if self.is_upper:
            return col_name.upper()
        return col_name

    def execute(
        self,
        query_raw,
        format_: Literal["dataframe", "iter"] = "dataframe",
    ) -> Union[pd.DataFrame, QueryReader]:
        """Render ``query_raw`` to SQL for this dataset and run it.

        Args:
            query_raw: A ``dataiku.sql`` query object accepted by ``toSQL``.
            format_: ``"dataframe"`` returns a DataFrame whose missing values are replaced by ``""``.
                ``"iter"`` returns ``query_to_iter(...).iter_tuples()``. Errors raised while
                consuming that iterator are not wrapped here.

        Raises:
            BadRequest: If the SQL cannot be generated or executing it fails.
            ValueError: If ``format_`` is not a supported value.
        """
        try:
            query = toSQL(query_raw, dataset=self.dataset)
        except Exception as err:
            raise BadRequest(f"Error when generating SQL query: {err}") from err

        if format_ == "dataframe":
            try:
                return self.executor.query_to_df(query=query).fillna("")
            except Exception as err:
                raise BadRequest(f"Error when executing SQL query: {err}") from err
        if format_ == "iter":
            try:
                return self.executor.query_to_iter(query=query).iter_tuples()
            except Exception as err:
                raise BadRequest(f"Error when executing SQL query: {err}") from err
        raise ValueError(f"Unsupported format_ '{format_}': expected 'dataframe' or 'iter'")

    def execute_commit(self) -> None:
        """Commit hook; a no-op in the generic manager."""
        return

    # TODO: refactor this method as it is a duplicate of the same method in conversation.py
    def __validate_columns(self, column_names: List[str]):
        """Return True if every name in ``column_names`` is one of ``self.columns``."""
        return all(name in self.columns for name in column_names)

    # TODO: refactor this method as it is a duplicate of the same method in conversation.py
    @log_execution_time
    def select_columns_from_dataset(  # noqa: PLR0917 too many positional arguments
        self,
        column_names: List[str],
        distinct: bool = False,
        eq_cond: Optional[List[WhereCondition]] = None,
        format_: Literal["dataframe", "iter"] = "dataframe",
        limit: Optional[int] = None,
        order_by: Optional[str] = None,
    ) -> pd.DataFrame:
        """Select ``column_names`` from the dataset, filtered by AND-ed equality conditions.

        Args:
            column_names: Columns to select. If equal to ``self.columns``, ``SELECT *`` is used.
            distinct: Whether to add ``DISTINCT``.
            eq_cond: Conditions combined with AND via ``_get_where_and_cond``. None means no conditions.
            format_: Passed to :meth:`execute`.
            limit: Row limit. A falsy value (None or 0) means no limit.
            order_by: Optional column name to order by.

        Returns:
            The query result as a DataFrame. For non-DataFrame results, an empty DataFrame is returned.
        """
        if eq_cond is None:
            eq_cond = []
        if not self.__validate_columns(column_names):
            unknown_columns = [name for name in column_names if name not in self.columns]
            logger.warning(f"Selecting columns not declared for dataset '{self.dataset_name}': {unknown_columns}")

        select_query = SelectQuery()
        if distinct:
            select_query.distinct()
        select_query.select_from(self.dataset)

        if column_names == self.columns:
            select_query.select(Column("*"))
        else:
            select_query.select([Column(str(col)) for col in column_names])

        select_query.where(_get_where_and_cond(eq_cond))

        if limit:
            select_query.limit(limit)

        if order_by:
            select_query.order_by(Column(str(order_by)))

        response = self.execute(select_query, format_=format_)
        if not isinstance(response, pd.DataFrame):
            # NOTE(review): with format_="iter", execute() returns an iterator, which is
            # discarded here in favor of an empty DataFrame. Kept as-is for compatibility.
            return pd.DataFrame()
        return response

    def insert_record_with_dataframe(self, record_df: pd.DataFrame):
        """Write ``record_df`` to the dataset through the dataset writer.

        This function must be used only when the database requires to insert data
        using 'Dataset appendMode' instead of SQL INSERT queries.

        Every column except the profile-last-updated one is cast to ``str`` before writing.
        The caller's DataFrame is not modified.

        Raises:
            BadRequest: If the write fails.
        """
        try:
            record_df = record_df.copy()
            string_columns = record_df.columns.difference([self.col(PROFILE_LAST_UPDATED)])
            record_df[string_columns] = record_df[string_columns].astype(str)
            logger.info(f"Inserting data using 'Dataset appendMode' on dataset '{self.dataset.full_name}' table")
            # NOTE(review): per the Dataiku docs, drop_and_create only takes effect together
            # with infer_schema=True. Confirm before changing either flag.
            self.dataset.write_dataframe(record_df, infer_schema=False, drop_and_create=True)
            logger.debug("Data successfully inserted")
        except Exception as e:
            # Use positional access with an emptiness guard, so that building the message
            # cannot itself raise and mask the original error.
            first_record = dict(record_df.iloc[0]) if not record_df.empty else {}
            error_message = (
                f"Error when inserting record with 'Dataset appendMode': '{e}' (The record was {first_record})."
            )
            logger.exception(error_message)
            raise BadRequest(error_message) from e

    def insert_record_with_sql_query(self, insert_query: Any):
        """Run a pre-built INSERT query, followed by the dataset's post-queries.

        Raises:
            BadRequest: If execution fails.
        """
        try:
            logger.info(f"Inserting data using 'SQL query' on dataset '{self.dataset.full_name}' table")
            self.executor.query_to_df(insert_query, post_queries=get_post_queries(self.dataset))
        except Exception as e:
            error_message = f"Error when inserting record with 'SQL Query': '{e}' (The query was {insert_query})."
            logger.exception(error_message)
            raise BadRequest(error_message) from e

    def insert_record(self, record_value: List[Any], thread_pool_executor: Optional[ThreadPoolExecutor] = None):
        """Insert one row whose values are given in the order of ``self.columns``.

        The dataset writer is used when 'appendMode' is enabled for the dialect. Otherwise a
        SQL INSERT is built and executed.

        When ``thread_pool_executor`` is given, the insert is submitted to it and this method
        returns immediately. The returned future is not kept, so failures are only visible
        through the log written by the insert helpers.
        """
        if self.insert_with_append_mode:
            record_df = pd.DataFrame([dict(zip(self.columns, record_value))])
            if thread_pool_executor is not None:
                thread_pool_executor.submit(self.insert_record_with_dataframe, record_df)
                logger.debug("Insert using 'Dataset appendMode' submitted to the executor")
            else:
                self.insert_record_with_dataframe(record_df)
        else:
            insert_query = (
                InsertQueryBuilder(self.dataset).add_columns(self.columns).add_values(values=[record_value]).build()
            )
            if thread_pool_executor is not None:
                logger.debug("Insert using 'SQL query' submitted to the executor")
                thread_pool_executor.submit(self.insert_record_with_sql_query, insert_query)
            else:
                self.insert_record_with_sql_query(insert_query)

class GenericMessageSQL(GenericSQLManager):
    """Abstract SQL manager for the individual messages of conversations."""

    @abstractmethod
    def add_message(self, user: str, message_info: MessageInsertInfo, timestamp: str) -> str:
        """Persist a message for ``user`` at ``timestamp``.

        Returns a string; presumably the new message id (confirm in subclasses).
        """
        raise NotImplementedError("Subclasses must implement add_message method")

    @abstractmethod
    def get_all_conversation_messages(
        self,
        platform: str,
        user: str,
        conversation_id: str,
        only_present: bool = True,
    ):
        """Return the messages of one conversation of ``user`` on ``platform``.

        ``only_present`` presumably excludes deleted messages (confirm in subclasses).
        """
        raise NotImplementedError("Subclasses must implement get_all_conversation_messages method")

    @abstractmethod
    def delete_conversation_messages(
        self,
        platform: str,
        user: str,
        conversation_id: str,
    ):
        """Delete the messages of one conversation of ``user`` on ``platform``."""
        raise NotImplementedError("Subclasses must implement delete_conversation_messages method")


class GenericLoggingDatasetSQL(GenericSQLManager):
    """Abstract SQL manager for the logging dataset (one record per question/answer)."""

    @abstractmethod
    def get_user_conversations(self, auth_identifier: str) -> List[ConversationInfo]:
        """Return the conversations belonging to ``auth_identifier``."""
        raise NotImplementedError("Subclasses must implement get_user_conversations method")

    @abstractmethod
    def get_conversation(
        self,
        auth_identifier: str,
        conversation_id: str,
        only_present: bool = True,
    ):
        """Return one conversation of ``auth_identifier``.

        ``only_present`` presumably excludes deleted records (confirm in subclasses).
        """
        raise NotImplementedError("Subclasses must implement get_conversation method")

    @abstractmethod
    def add_record(  # noqa: PLR0917 too many positional arguments
        self,
        record: QuestionData,
        auth_identifier: str,
        conversation_id: Optional[str],
        conversation_name: Optional[str],
        knowledge_bank_id: Optional[str],
        llm_id: Optional[str] = None,
    ) -> Tuple[str, ConversationInfo]:
        """Store ``record`` and return a ``(str, ConversationInfo)`` pair.

        The string is presumably the id of the new record (confirm in subclasses).
        """
        raise NotImplementedError("Subclasses must implement add_record method")

    @abstractmethod
    def update_answer(self, auth_identifier: str, conversation_id: str, message_id: str, answer: str) -> None:
        """Replace the answer stored for ``message_id`` in the given conversation."""
        raise NotImplementedError("Subclasses must implement update_answer method")

    @abstractmethod
    def clear_conversation_history(self, auth_identifier: str, conversation_id: str) -> None:
        """Clear the history of one conversation of ``auth_identifier``."""
        raise NotImplementedError("Subclasses must implement clear_conversation_history method")

    @abstractmethod
    def delete_user_conversation(self, auth_identifier: str, conversation_id: str) -> None:
        """Delete one conversation of ``auth_identifier``."""
        raise NotImplementedError("Subclasses must implement delete_user_conversation method")

    @abstractmethod
    def delete_all_user_conversations(self, auth_identifier: str) -> None:
        """Delete every conversation of ``auth_identifier``."""
        raise NotImplementedError("Subclasses must implement delete_all_user_conversations method")

    @abstractmethod
    def update_feedback(self, auth_identifier: str, conversation_id: str, message_id: str, feedback: Feedback) -> None:
        """Attach ``feedback`` to ``message_id`` in the given conversation."""
        raise NotImplementedError("Subclasses must implement update_feedback method")


class GenericConversationSQL(GenericSQLManager):
    """Abstract SQL manager for conversation metadata (one row per conversation)."""

    @abstractmethod
    def add_conversation(self, user: str, conversation_info: ConversationInsertInfo) -> ConversationInfo:
        """Create a conversation for ``user`` and return its ``ConversationInfo``."""
        raise NotImplementedError("Subclasses must implement add_conversation method")

    @abstractmethod
    def delete_conversation(self, user: str, conversation_id: str):
        """Delete one conversation of ``user``."""
        raise NotImplementedError("Subclasses must implement delete_conversation method")

    @abstractmethod
    def get_all_user_conversations(
        self,
        user: str,
        present_only: bool = True,
    ) -> List[ConversationInfo]:
        """Return the conversations of ``user``.

        ``present_only`` presumably excludes deleted ones (confirm in subclasses).
        """
        raise NotImplementedError("Subclasses must implement get_all_user_conversations method")

    @abstractmethod
    def update_conversation_metadata(
        self,
        user: str,
        conversation_id: str,
        column_updates: Dict[str, Any],
        platform: Optional[str],
    ):
        """Apply ``column_updates`` (column name -> new value) to one conversation of ``user``."""
        raise NotImplementedError("Subclasses must implement update_conversation_metadata method")


class GenericUserProfileSQL(GenericSQLManager):
    """Abstract SQL manager for per-user profile records."""

    @abstractmethod
    def get_user_profile(self, auth_identifier: str) -> Optional[Dict]:
        """Return the profile of ``auth_identifier``, or None (presumably when no profile exists)."""
        raise NotImplementedError("Subclasses must implement get_user_profile method")

    @abstractmethod
    def add_user_profile(self, user: str, profile: Dict):
        """Create the profile of ``user``."""
        raise NotImplementedError("Subclasses must implement add_user_profile method")

    @abstractmethod
    def update_user_profile(self, user: str, profile: Dict):
        """Overwrite or update the profile of ``user``."""
        raise NotImplementedError("Subclasses must implement update_user_profile method")

    @abstractmethod
    def update_generated_images_count_in_db(
        self,
        user: str,
        user_profile: Dict[str, Any],
        num_generated_images: int,
        config: Dict[str, str],
    ) -> Dict[str, Any]:
        """Record ``num_generated_images`` for ``user`` and return the resulting profile dict."""
        raise NotImplementedError("Subclasses must implement update_generated_images_count_in_db method")


class GenericFeedbackSQL(GenericSQLManager):
    """Abstract SQL manager for user feedback entries."""

    @abstractmethod
    def add_feedback(
        self,
        timestamp: datetime,
        user: str,
        message: str,
        knowledge_bank_id: Optional[str] = None,
        llm_id: Optional[str] = None,
    ):
        """Store a feedback ``message`` from ``user`` at ``timestamp``.

        Optionally records the knowledge bank and LLM ids it relates to.
        """
        raise NotImplementedError("Subclasses must implement add_feedback method")
