from typing import Any, List, Literal, Optional, Tuple, Union

import pandas as pd
from common.backend.db.sql.queries import WhereCondition, _get_where_and_cond  #TODO: this import needs to change
from common.backend.utils.dataiku_api import dataiku_api
from dataiku import Dataset, SQLExecutor2
from dataiku.sql import Column, SelectQuery, toSQL
from dataiku.sql.expression import Operator
from werkzeug.exceptions import BadRequest


class BsDsSql:
    """Query helper over the dataset named by the webapp setting ``sql_retrieval_table_name``.

    Builds Dataiku ``SelectQuery`` objects from simple column lists and
    ``(column, operator, value)`` filter tuples, executes them through
    ``SQLExecutor2`` and reports failures as ``BadRequest``.
    """

    # Textual comparison operators accepted by ``parse_conditions``
    # ("==" is normalised to "=" before the lookup).
    OPERATORS_MAP = {
        "=": Operator.EQ,
        "!=": Operator.NE,
        ">=": Operator.GE,
        "<": Operator.LT,
    }

    # Raw filter format accepted by ``sql_to_df``: a list of (column, operator, value).
    WhereClauseType = List[Tuple[str, str, Union[str, int, float]]]

    def __init__(self):
        # TODO: check table is correct.
        sql_retrieval_table_name = dataiku_api.webapp_config.get("sql_retrieval_table_name")
        if sql_retrieval_table_name is not None:
            self.dataset = Dataset(
                project_key=dataiku_api.default_project_key, name=sql_retrieval_table_name
            )
            self.executor = SQLExecutor2(dataset=self.dataset)
        else:
            # Not configured: query methods raise BadRequest via _ensure_configured().
            self.dataset = None
            self.executor = None

    def _ensure_configured(self) -> None:
        """Raise ``BadRequest`` if no retrieval table was configured at construction."""
        if self.dataset is None or self.executor is None:
            raise BadRequest(
                "SQL retrieval table is not configured "
                "(webapp setting 'sql_retrieval_table_name' is missing)."
            )

    def execute(
        self,
        query_raw,
        format_: Literal["dataframe", "iter"] = "dataframe",
    ):
        """Render ``query_raw`` to SQL and run it against the configured dataset.

        Args:
            query_raw: A Dataiku SQL query object accepted by ``toSQL``
                (e.g. a ``SelectQuery``).
            format_: ``"dataframe"`` returns the result of ``query_to_df`` with
                missing values replaced by ``""``; ``"iter"`` returns the row
                iterator from ``query_to_iter(...).iter_tuples()``.

        Raises:
            ValueError: if ``format_`` is not ``"dataframe"`` or ``"iter"``.
            BadRequest: if no table is configured, or SQL generation or
                execution fails.
        """
        if format_ not in ("dataframe", "iter"):
            raise ValueError(
                f"Unsupported format_: {format_!r} (expected 'dataframe' or 'iter')"
            )
        self._ensure_configured()
        try:
            query = toSQL(query_raw, dataset=self.dataset)
        except Exception as err:
            raise BadRequest(f"Error when generating SQL query: {err}") from err
        try:
            if format_ == "dataframe":
                return self.executor.query_to_df(query=query).fillna("")
            # NOTE(review): iter_tuples() presumably yields rows lazily, so errors
            # raised while the caller consumes it would not be wrapped here — confirm.
            return self.executor.query_to_iter(query=query).iter_tuples()
        except Exception as err:
            raise BadRequest(f"Error when executing SQL query: {err}") from err

    def select_columns_from_dataset(  # noqa: PLR0917 too many positional arguments
        self,
        column_names: Union[List[str], str],
        distinct: bool = False,
        cond: Optional[List[WhereCondition]] = None,
        format_: Literal["dataframe", "iter"] = "dataframe",
        limit: Optional[int] = None,
        order_by: Optional[str] = None,
    ):
        """Build and run a SELECT over the configured dataset.

        Args:
            column_names: Column names to select, or a single column name
                (e.g. ``"*"``).
            distinct: Add DISTINCT to the query when true.
            cond: Conditions combined with AND via ``_get_where_and_cond``;
                ``None`` or empty means no WHERE clause.
            format_: Result format, forwarded to ``execute``.
            limit: Maximum number of rows; ``None`` means no LIMIT clause.
            order_by: Optional column name to order by.

        Returns:
            Whatever ``execute`` returns for ``format_``.

        Raises:
            BadRequest: if no table is configured or the query fails.
        """
        self._ensure_configured()
        # A bare string is one column name; iterating it directly would
        # create one Column per character.
        if isinstance(column_names, str):
            column_names = [column_names]
        columns_to_select = [Column(str(col)) for col in column_names]

        select_query = SelectQuery()
        if distinct:
            select_query.distinct()
        select_query.select_from(self.dataset)
        select_query.select(columns_to_select)
        if cond:
            select_query.where(_get_where_and_cond(cond))
        # `is not None` so that limit=0 is forwarded instead of being
        # treated as "no limit".
        if limit is not None:
            select_query.limit(limit)
        if order_by:
            select_query.order_by(Column(str(order_by)))

        return self.execute(select_query, format_=format_)

    def parse_conditions(self, where_):
        """Convert ``(column, operator, value)`` tuples into ``WhereCondition`` objects.

        Column names and string values are stripped of surrounding whitespace.
        The operator is stripped, and ``"=="`` is accepted as an alias of ``"="``.

        Raises:
            BadRequest: if an operator is not a key of ``OPERATORS_MAP``.
        """
        cond = []
        for w in where_:
            col = w[0].strip()
            op_key = w[1].replace("==", "=").strip()
            try:
                operator = self.OPERATORS_MAP[op_key]
            except KeyError:
                supported = ", ".join(sorted(self.OPERATORS_MAP))
                raise BadRequest(
                    f"Unsupported operator {w[1]!r} in condition on column {col!r}; "
                    f"supported operators: {supported} (and '==')."
                ) from None
            value = w[2].strip() if isinstance(w[2], str) else w[2]
            cond.append(WhereCondition(column=col, operator=operator, value=value))
        return cond

    def sql_to_df(
        self,
        select: Union[List[str], str],
        where_: Union[WhereClauseType, None] = None,
        limit: Union[int, None] = None,
        order_by: Union[str, None] = None,
    ) -> Union[pd.DataFrame, None, Any]:
        """Select ``select`` columns, filtered by ``where_``, as a DataFrame.

        Args:
            select: Column names, or a single column name such as ``"*"``.
            where_: Optional list of ``(column, operator, value)`` tuples,
                combined with AND.
            limit: Optional maximum number of rows.
            order_by: Optional column name to order by.

        Raises:
            BadRequest: on an unsupported operator, a missing table
                configuration, or SQL generation/execution errors.
        """
        cond = self.parse_conditions(where_) if where_ is not None else []
        return self.select_columns_from_dataset(
            column_names=select, cond=cond, order_by=order_by, limit=limit
        )



# Shared module-level instance, created at import time. Construction reads the
# webapp config and, when "sql_retrieval_table_name" is set, builds the Dataset
# and SQLExecutor2 that every query on this instance uses.
bs_ds_sql = BsDsSql()

# Example usage (kept for reference only; this code is commented out and never runs):
# if __name__ == '__main__':
    # filter_dict = {'prospect_first_name': 'Petr',
    # 'prospect_last_name': 'Kadlec'}
    # where_ = [(k,'==', v) for (k,v) in filter_dict.items()]

    # df = bs_ds_sql.sql_to_df(
    #                 select = '*',
    #                 where_ = where_,
    # )
    # prospect_info = df.to_dict('records')