from typing import Optional

from common.backend.constants import DEFAULT_MAX_LLM_TOKENS, DEFAULT_TEMPERATURE
from common.backend.utils.config_utils import resolve_webapp_param
from common.backend.utils.dataiku_api import dataiku_api
from dataiku.langchain.dku_llm import DKULLM


class LLM_API_Setup:
    """Hold the webapp's LLM settings and build ``DKULLM`` clients from them.

    Settings are captured once, in ``__init__``. ``get_llm`` turns them into
    a ``DKULLM`` instance on demand.
    """

    def __init__(self, dataiku_api):
        """Read the LLM id and the generation parameters.

        Args:
            dataiku_api: object exposing a ``webapp_config`` mapping. The
                ``"llm_id"`` and ``"show_advanced_settings"`` keys are read
                from it. The object itself is kept as ``self.dataiku_api``.
        """
        # Most recent client built by ``get_llm``; None until the first call.
        self.llm: Optional[DKULLM] = None
        self.dataiku_api = dataiku_api
        self.llm_id = self.dataiku_api.webapp_config.get("llm_id")
        # A missing, None, or otherwise falsy setting counts as "advanced
        # settings disabled".
        advanced = self.dataiku_api.webapp_config.get("show_advanced_settings", False) or False
        # Both parameters go through the shared resolver. The DEFAULT_*
        # constants are handed over as default values; the exact fallback
        # rules live in ``resolve_webapp_param``.
        self.max_tokens: int = resolve_webapp_param(
            "max_llm_tokens",
            default_value=DEFAULT_MAX_LLM_TOKENS,
            advanced_mode_enabled=advanced,
        )
        self.temperature: Optional[float] = resolve_webapp_param(
            "llm_temperature",
            default_value=DEFAULT_TEMPERATURE,
            advanced_mode_enabled=advanced,
        )

    def get_llm(self) -> DKULLM:
        """Build a new ``DKULLM`` for the configured LLM id.

        Each call creates a fresh client and stores it on ``self.llm``
        before returning it.

        Returns:
            The newly created ``DKULLM``.

        Raises:
            ValueError: if no (or an empty) LLM id was found in the webapp
                config.
        """
        if not self.llm_id:
            raise ValueError("A Dataiku LLM ID must be provided")
        llm = self.llm = DKULLM(llm_id=self.llm_id, max_tokens=self.max_tokens)
        # Temperature is set after construction: before DSS 14, a None
        # temperature cannot be passed to the DKULLM constructor.
        llm.temperature = self.temperature
        return llm


# Module-level shared instance, built at import time from the imported
# ``dataiku_api`` object.
llm_setup = LLM_API_Setup(dataiku_api=dataiku_api)