from __future__ import annotations

import logging
import os
import tempfile
from typing import Any, Dict, List

import dataiku
from dataiku.core.intercom import TicketProxyingDSSClient
from dataiku.customwebapp import get_webapp_config
from dataikuapi.dss.llm import DSSLLM
from dataikuapi.dssclient import DSSClient

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class WebAppConfig:
    """Holds the webapp's runtime configuration.

    An instance starts with default values and is populated by a single call
    to :meth:`setup`. Subsequent calls to :meth:`setup` are no-ops. Values are
    exposed through read-only properties.
    """

    def __init__(self) -> None:
        """Initialise defaults and obtain a Dataiku API client."""
        self.__is_setup = False
        self.__webapp_config: Dict = {}
        self.__default_project_key = ""
        self.__db_explorer_folder_paths: List[str] = [""]
        self.__db_local_folder_path = ""
        self.__db_query_timeout_seconds = 60
        self._client = dataiku.api_client()
        self.__llm_id: str | None = None
        self.__logging_ds: str | None = None

    def setup(self, webapp_config: Dict | None = None, default_project_key: str | None = None) -> None:
        """Populate the configuration; only the first call has any effect.

        Args:
            webapp_config: Configuration mapping. When falsy, the value of
                ``get_webapp_config()`` is used instead.
            default_project_key: Project key. When falsy, the ``projectKey``
                entry of ``dataiku.get_custom_variables()`` is used instead.

        Raises:
            AssertionError: If the resolved webapp config or project key is
                empty.
        """
        if self.__is_setup:
            return

        self.__webapp_config = webapp_config or get_webapp_config()
        assert self.__webapp_config, "The webapp config must not be empty."

        self.__default_project_key = default_project_key or dataiku.get_custom_variables()["projectKey"]
        assert self.__default_project_key, "The default project key must not be empty."

        logger.info("Webapp config is %s.", self.__webapp_config)

        # When debugging locally, use a path passed as an environment variable
        # to store files instead of a temporary directory.
        # mkdtemp() is used rather than TemporaryDirectory().name: the latter
        # object is discarded immediately and its finalizer removes the
        # directory, leaving a path that no longer exists.
        parent_folder = os.getenv("DEBUG_RUN_FOLDER") or tempfile.mkdtemp()
        self.__db_local_folder_path = self.__create_db_folder__(parent_folder)

        logger.info("Kuzu files are persisted locally at '%s'.", self.__db_local_folder_path)

        # Optional settings: a missing or falsy value keeps the default.
        explorer_paths = self.__webapp_config.get("db_explorer_folder_paths")
        if explorer_paths:
            self.__db_explorer_folder_paths = explorer_paths

        query_timeout = self.__webapp_config.get("db_query_timeout_seconds")
        if query_timeout:
            self.__db_query_timeout_seconds = int(query_timeout)

        llm_id = self.__webapp_config.get("llm_id")
        if llm_id:
            self.__llm_id = llm_id

        logging_ds = self.__webapp_config.get("logging_ds")
        if logging_ds:
            self.__logging_ds = logging_ds

        self.__is_setup = True

    @property
    def client(self) -> DSSClient | TicketProxyingDSSClient:
        """The Dataiku API client obtained at construction time.

        Raises:
            Exception: If no client is available.
        """
        if self._client is None:
            raise Exception("Please set the client before using it.")
        return self._client

    @client.setter
    def client(self, c: Any) -> None:
        """Reject assignment; this setter always raises."""
        raise Exception("If working outside of Dataiku, Client can only be set through the function setup()")

    @property
    def default_project_key(self) -> str:
        """The project key resolved by :meth:`setup` (empty string before)."""
        return self.__default_project_key

    @property
    def db_folder_path(self) -> str:
        """Local folder path where database files are stored."""
        return self.__db_local_folder_path

    @property
    def db_query_timeout_seconds(self) -> int:
        """Query timeout in seconds (60 unless overridden by the config)."""
        return self.__db_query_timeout_seconds

    @property
    def db_explorer_folder_paths(self) -> List[str]:
        """Folder paths from the ``db_explorer_folder_paths`` setting."""
        return self.__db_explorer_folder_paths

    @property
    def db_explorer_folders(self) -> List[dataiku.Folder]:
        """A ``dataiku.Folder`` built from each configured explorer path."""
        return [dataiku.Folder(path) for path in self.db_explorer_folder_paths]

    @property
    def llm_id(self) -> str | None:
        """The configured LLM id, or ``None`` when not set."""
        return self.__llm_id

    @property
    def logging_ds(self) -> str | None:
        """The configured ``logging_ds`` value, or ``None`` when not set."""
        return self.__logging_ds

    def __create_db_folder__(self, parent_path: str) -> str:
        """Create ``<parent_path>/graph-instances`` if needed and return it.

        Missing intermediate directories are created as well; an already
        existing folder is left untouched.
        """
        p = os.path.join(parent_path, "graph-instances")
        # Create the directory at the returned path, not relative to the
        # current working directory.
        os.makedirs(p, exist_ok=True)
        return p

    def get_llm(self) -> DSSLLM | None:
        """Return the LLM handle for the configured id from the default project.

        Returns:
            ``None`` when no LLM id is configured; otherwise the result of
            ``project.get_llm(llm_id)``.
        """
        if not self.llm_id:
            return None

        client = dataiku.api_client()
        project = client.get_project(self.default_project_key)
        return project.get_llm(self.llm_id)


# Shared module-level configuration instance. Constructing it calls
# dataiku.api_client() at import time; setup() must be called on it to
# populate the configuration values.
webapp_config = WebAppConfig()
