import json
from typing import Any, Dict, Optional

from common.backend.db.sql.queries import (
    Operator,
    UpdateQueryBuilder,
    WhereCondition,
    _get_where_and_cond,
    get_post_queries,
)
from common.backend.db.sql.tables_managers import GenericUserProfileSQL
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.date_utils import get_string_date_interpreted_as_datetime
from common.backend.utils.sql_timing import log_execution_time
from common.backend.utils.user_profile_utils import update_user_profile_generated_media_info
from common.llm_assist.logging import logger
from dataiku.sql import Column, SelectQuery
from werkzeug.exceptions import BadRequest

# Column layout of the user-profile dataset. The record values built by
# UserProfileSQL.add_user_profile follow this same order.
COLUMNS = [
    "user",
    "profile",
    "last_updated",
]

# Passed to GenericUserProfileSQL as `dataset_conf_id`; presumably the webapp
# config key naming the user-profile dataset — TODO confirm.
USER_PROFILE_DATASET_CONF_ID = "user_profile_dataset"


class UserProfileSQL(GenericUserProfileSQL):
    """SQL-backed store of per-user profiles.

    Each row holds the user identifier, the profile serialized with
    ``json.dumps`` and a ``last_updated`` date string produced by
    ``get_string_date_interpreted_as_datetime`` (see ``COLUMNS``).
    """

    def __init__(self):
        config = dataiku_api.webapp_config
        super().__init__(
            config=config,
            columns=COLUMNS,
            dataset_conf_id=USER_PROFILE_DATASET_CONF_ID,
            default_project_key=dataiku_api.default_project_key,
        )

    @log_execution_time
    def get_user_profile(self, auth_identifier: str) -> Optional[Dict]:
        """Fetch and decode the stored profile of a user.

        Args:
            auth_identifier: Value matched (equality) against the ``user`` column.

        Returns:
            The decoded profile as a dict, or ``None`` when no row matches.

        Raises:
            BadRequest: If the SQL query fails, or if the stored profile cannot
                be read from the result / decoded into a dict.
        """
        eq_cond = [WhereCondition(column=self.col("user"), value=auth_identifier, operator=Operator.EQ)]
        where_cond = _get_where_and_cond(eq_cond)
        select_query = SelectQuery().select_from(self.dataset).select(Column(self.col("profile"))).where(where_cond)
        try:
            result = self.execute(select_query, format_="dataframe")
        except Exception as err:
            logger.error(err)
            raise BadRequest(f"Error when executing SQL query: {err}") from err

        if result.empty:
            logger.debug(f"No profile found for user {auth_identifier}")
            return None

        try:
            # Only the first matching row is used; the query itself does not
            # enforce one row per user.
            profile_str = result.iloc[0][self.col("profile")]
            # A NULL/NaN cell raises TypeError, malformed JSON raises ValueError
            # (json.JSONDecodeError), and a non-mapping JSON value makes dict()
            # raise TypeError/ValueError.
            return dict(json.loads(profile_str))
        except (KeyError, TypeError, ValueError) as err:
            # The query succeeded; report the stored data as the problem instead
            # of mislabeling it as an SQL failure.
            logger.error(err)
            raise BadRequest(f"Error when reading stored user profile: {err}") from err

    @log_execution_time
    def add_user_profile(
        self,
        user: str,
        profile: Dict,
    ):
        """Insert a new profile row for ``user``.

        Args:
            user: Identifier stored in the ``user`` column.
            profile: Profile data, serialized with ``json.dumps`` (non-ASCII kept as-is).
        """
        logger.debug(f"Adding the user profile '{user}'...")
        last_updated = get_string_date_interpreted_as_datetime(self.dialect)

        # Values are ordered to match COLUMNS: user, profile, last_updated.
        record_value = [
            user,
            json.dumps(profile, ensure_ascii=False),
            last_updated,
        ]
        self.insert_record(record_value)

    @log_execution_time
    def update_user_profile(self, user: str, profile: Dict):
        """Overwrite the stored profile and ``last_updated`` of every row matching ``user``.

        Args:
            user: Identifier matched (equality) against the ``user`` column.
            profile: New profile data, serialized with ``json.dumps``.

        Raises:
            BadRequest: If executing the UPDATE query fails.
        """
        eq_cond = [
            WhereCondition(column=self.col("user"), value=user, operator=Operator.EQ),
        ]
        last_updated = get_string_date_interpreted_as_datetime(self.dialect)

        update_query = (
            UpdateQueryBuilder(self.dataset)
            .add_set_cols(
                [
                    (self.col("profile"), json.dumps(profile, ensure_ascii=False)),
                    (self.col("last_updated"), last_updated),
                ]
            )
            .add_conds(eq_cond)
            .build()
        )
        try:
            logger.debug(f"Updating user profile for user {user}")
            logger.debug(f"Update query: {update_query}")
            self.executor.query_to_df(update_query, post_queries=get_post_queries(self.dataset))
        except Exception as err:
            logger.error(err)
            raise BadRequest(f"Error when executing SQL query: {err}") from err

    def update_generated_images_count_in_db(
        self, user: str, user_profile: Dict[str, Any], num_generated_images: int, config: Dict[str, str]
    ) -> Dict[str, Any]:
        """Persist newly generated images in the user's profile when a weekly quota is configured.

        Nothing is written unless ``max_images_per_user_per_week`` is a positive
        number and ``num_generated_images`` is positive. When ``user_profile``
        carries a truthy ``new_user_profile`` flag, a new row containing only
        ``generated_media_info`` is inserted. Otherwise the existing row is
        updated with the whole profile.

        Args:
            user: Identifier of the user whose profile is written.
            user_profile: Current profile of the user.
            num_generated_images: Number of images just generated.
            config: Settings; ``max_images_per_user_per_week`` is read from it.

        Returns:
            The profile returned by ``update_user_profile_generated_media_info``,
            or ``user_profile`` unchanged when nothing was written.
        """
        # A missing, None or empty setting means "no quota", the same as the 0
        # default; previously None/"" crashed int().
        max_images_per_week = int(config.get("max_images_per_user_per_week") or 0)
        logger.debug(f"Max images per week: {max_images_per_week}")
        logger.debug(f"Number of images generated: {num_generated_images}")
        # Proceed only if there is a valid max image count specified
        if max_images_per_week > 0 and num_generated_images > 0:
            # Read the flag before the helper below returns a new/updated profile.
            user_profile_exists = not user_profile.get("new_user_profile", False)
            user_profile = update_user_profile_generated_media_info(user_profile, num_generated_images)
            if user_profile_exists:
                # Update the user profile with the new media data
                self.update_user_profile(user, user_profile)
            else:
                # Create a new user profile with the generated media data
                self.add_user_profile(user, {"generated_media_info": user_profile["generated_media_info"]})
        return user_profile


# Shared module-level instance. It is constructed at import time, so importing
# this module reads dataiku_api.webapp_config (via UserProfileSQL.__init__).
user_profile_sql_manager = UserProfileSQL()
