from typing import Any, Dict, List, Union, cast

import dataiku
import dataikuapi
import requests
from answers.backend.db.user_profile import user_profile_sql_manager
from answers.backend.routes.utils import return_ko, return_ok
from answers.backend.utils.config_utils import get_retriever_info
from answers.backend.utils.knowledge_filters import get_knowledge_bank_filtering_settings
from common.backend.constants import DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, MAX_N_UPLOAD_FILES, MAX_UPLOAD_SIZE_MB
from common.backend.models.base import RetrieverMode
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.llm_utils import get_llm_capabilities
from common.llm_assist.fallback import get_fallback_id
from common.llm_assist.log_decorators import log_http_request
from common.llm_assist.logging import logger
from flask import Blueprint, Response, request

# ``cast`` is only a hint for mypy: each configured entry is expected to be a
# dict with "from" and "to" string keys. Nothing is converted at runtime.
languages = cast(List[Dict[str, str]], dataiku_api.webapp_config.get("user_profile_languages", []))

# Re-shape the configured language mapping into the {"key": ..., "value": ...}
# pairs returned to the UI by ``get_ui_setup`` under "supported_languages".
SUPPORTED_LANGUAGES = [dict(key=entry["from"], value=entry["to"]) for entry in languages]
config_blueprint = Blueprint("config", __name__, url_prefix="/config")


def format_feedback_choices(choices: Union[List[Any], None]) -> List[str]:
    """Convert the configured feedback choices to strings for the UI.

    Args:
        choices: Raw values read from the webapp configuration. ``None`` (a
            key present in the config but left empty) is treated as an empty
            list.

    Returns:
        ``str(choice)`` for every choice that could be converted, in order.
        Entries whose conversion raises are logged and skipped.
    """
    choices_final: List[str] = []
    for choice in choices or []:
        try:
            choices_final.append(str(choice))
        except Exception as e:
            # Best-effort: one bad entry must not break the whole UI setup.
            logger.warning(f"choice can't be parsed to str: {e}")
    return choices_final


def get_user(headers):
    """Resolve the DSS identity of the caller from the browser request headers.

    Args:
        headers: The incoming HTTP request headers, as a dict.

    Returns:
        The ``authIdentifier`` reported by DSS, or the sentinel string
        ``"user_not_found"`` when the lookup raises a ``DataikuException``
        or an ``HTTPError`` (the error is logged).
    """
    try:
        identifier = dataiku.api_client().get_auth_info_from_browser_headers(headers)["authIdentifier"]
    except (dataikuapi.utils.DataikuException, requests.exceptions.HTTPError) as err:
        logger.error(f"Exception occurred: {str(err)}")
        return "user_not_found"
    return identifier


def get_current_user_profile(headers):
    """Fetch the stored profile of the user identified by ``headers``.

    Returns:
        Whatever ``user_profile_sql_manager.get_user_profile`` returns for the
        resolved user, or ``None`` when ``get_user`` could not identify them.
    """
    current_user = get_user(headers)
    if current_user != "user_not_found":
        return user_profile_sql_manager.get_user_profile(current_user)
    return None
    

@config_blueprint.route("/get_ui_setup", methods=["GET"])
@log_http_request
def get_ui_setup() -> Response:
    """
    Fetches the configuration settings for UI setup from Dataiku and returns them.

    The payload combines webapp settings (titles, example questions, feedback
    choices, branding, upload limits), retrieval settings, the main LLM's
    capabilities and the identity/profile of the requesting user.

    Returns:
        Response: A Flask response object containing the UI setup data.

    Raises:
        ValueError: If the main LLM reports the ``image_generation`` capability
            but no ``upload_folder`` is configured.
    """
    config: Dict[str, Any] = dataiku_api.webapp_config
    examples: Union[str, List[str]] = config.get("example_questions", [])
    title: Union[str, None] = config.get("web_app_title")
    subtitle: Union[str, None] = config.get("web_app_subheading")
    lang: str = config.get("language", "en")
    placeholder: str = config.get("web_app_input_placeholder", "")
    feedback_positive_choices: List[str] = format_feedback_choices(
        config.get("feedback_positive_choices", [])
    )
    feedback_negative_choices: List[str] = format_feedback_choices(
        config.get("feedback_negative_choices", [])
    )

    retrieval_mode = config.get("retrieval_mode")
    # Knowledge-bank filters are only relevant in knowledge-bank retrieval mode.
    filters_config = (
        get_knowledge_bank_filtering_settings(config.get("knowledge_bank_id"))
        if retrieval_mode == RetrieverMode.KB.value
        else None
    )

    llm_caps = get_llm_capabilities()
    main_streams = llm_caps.get("streaming", False)
    fallback_streams = get_llm_capabilities(get_fallback=True).get("streaming", False)
    fallback_id = get_fallback_id()
    # Only warn when a fallback LLM is actually configured.
    if main_streams and not fallback_streams and fallback_id:
        logger.warning("Main LLM supports streaming, but fallback LLM does not. Please check your configuration.")
    if llm_caps.get("image_generation") and not config.get("upload_folder"):
        raise ValueError("Image generation requires an upload folder to be set. Please configure one in the webapp settings and restart.")

    disclaimer = config.get("disclaimer", "")
    # TODO: temporary fix while the in-memory RAG is not used; "TMP" is the
    # placeholder returned when no self-service embedding model is configured.
    self_service_embedding_model = config.get("self_service_embedding_model", "TMP")
    # A literal "None" string in the config is normalized to a real None.
    if self_service_embedding_model == "None":
        self_service_embedding_model = None

    headers = dict(request.headers)
    # Resolve the user once and reuse it for the profile lookup. Every
    # get_user call goes through the DSS API client.
    user = get_user(headers)
    user_profile = None if user == "user_not_found" else user_profile_sql_manager.get_user_profile(user)

    result = {
        "examples": examples,
        "title": title,
        "subtitle": subtitle,
        "language": lang,
        "default_user_language": config.get("default_user_language", "English"),
        "user_settings": config.get("user_profile_settings", []),
        "current_user_profile": user_profile,
        "supported_languages": SUPPORTED_LANGUAGES,
        "input_placeholder": placeholder,
        "project": dataiku_api.default_project_key,
        "upload_folder_id": config.get("upload_folder"),
        "self_service_embedding_model": self_service_embedding_model,
        "feedback_negative_choices": feedback_negative_choices,
        "feedback_positive_choices": feedback_positive_choices,
        "filters_config": filters_config,
        "retrieval_mode": retrieval_mode,
        "retriever_info": get_retriever_info(config),
        "display_sql_query": config.get("display_sql_query", False),
        "sql_retrieval_connection": config.get("sql_retrieval_connection", ""),
        "sql_retrieval_suggested_joins": config.get("sql_retrieval_suggested_joins", []),
        "display_source_chunks": config.get("display_source_chunks", True),
        "llm_capabilities": llm_caps,
        "llm_id": config.get("llm_id", ""),
        # Only expose the image-generation LLM when the feature is enabled.
        "image_generation_llm_id": config.get("image_generation_llm_id", "")
        if config.get("enable_image_generation", False)
        else "",
        "user": user,
        "allow_general_feedback": config.get("allow_general_feedback", False),
        "disclaimer": disclaimer,
        "use_custom_rebranding": config.get("use_custom_rebranding", False),
        "custom_theme_name": config.get("custom_theme_name", ""),
        "custom_logo_file_name": config.get("custom_logo_file_name", ""),
        "custom_icon_file_name": config.get("custom_icon_file_name", ""),
        "max_upload_size_mb": config.get("max_upload_size_mb", MAX_UPLOAD_SIZE_MB),
        "max_n_upload_files": config.get("max_n_upload_files", MAX_N_UPLOAD_FILES),
        "image_extensions": list(IMAGE_EXTENSIONS),
        "document_extensions": list(DOCUMENT_EXTENSIONS),
        "max_images_per_user_per_week": config.get("max_images_per_user_per_week", 0)
    }
    return return_ok(data=result)  # type: ignore


@config_blueprint.route("/user/profile", methods=["POST"])
@log_http_request
def update_user_profile() -> str:
    """Create or update the profile of the user making the request.

    Expects a JSON object body carrying the new profile under ``"profile"``.
    The stored profile is updated when one already exists for the user;
    otherwise a new one is added.

    Returns:
        ``return_ok`` wrapping the result of the database call, or
        ``return_ko`` when the user cannot be identified, the body is not a
        JSON object, or the database call raises.
    """
    headers = dict(request.headers)
    user = get_user(headers)
    # get_user signals authentication failures with a sentinel instead of
    # raising. Without this guard, every unidentified caller would read and
    # overwrite one shared "user_not_found" profile.
    if user == "user_not_found":
        return return_ko(message="Unable to identify the current user")
    settings = request.get_json()
    # A JSON body such as `null` or a list has no `.get`; reject it explicitly.
    if not isinstance(settings, dict):
        return return_ko(message="Invalid request body: expected a JSON object")
    new_profile = settings.get("profile")
    try:
        if user_profile_sql_manager.get_user_profile(user):
            result = user_profile_sql_manager.update_user_profile(
                user=user, profile=new_profile)
        else:
            result = user_profile_sql_manager.add_user_profile(
                user=user, profile=new_profile)
    except Exception as e:
        logger.error(f"Error updating user profile: {str(e)}")
        return return_ko(message="Error updating user profile")
    return return_ok(data=result)