import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import dataiku
import sql_tool_prompts  # type: ignore[import]
from dataiku.llm.agent_tools import BaseAgentTool
from dataiku.sql import toSQL
from dataikuapi.dss.llm import DSSLLM, DSSLLMCompletionQuery, DSSLLMCompletionResponse
from json_utils import extract_json  # type: ignore[import]
from run_context import RunContext  # type: ignore[import]
from sql_management import SQLManager  # type: ignore[import]
from sql_query_utils import (  # type: ignore[import]
    replace_tables_in_select_query,
    to_select_query,
    is_dataset_access_error,
)
from utils.messages import TOOL_DECLINED_TO_ANSWER  # type: ignore[import]

PROMPT_DATE_FORMAT = "%Y-%m-%d %H:%M"


class SQLQueryTool(BaseAgentTool):
    """Agent tool answering a natural-language question with a single SQL query.

    ``invoke`` runs the whole pipeline:

    1. ``is_query_required`` asks the LLM whether the configured datasets can
       answer the question and which tables/columns to use.
    2. ``generate_query`` asks the LLM for a JSON description of a SELECT
       query; ``run_sql_query`` turns it into SQL and executes it. Failures are
       accumulated as text and fed back to the LLM for up to
       ``max_query_attempts`` rounds.
    3. ``synthesize_final_answer`` turns the records into a text answer and/or
       a RECORDS artifact, depending on ``return_value_mode``.

    Config keys read by this class: ``llmId``, ``datasets``, ``connection``,
    ``hard_sql_limit``, ``additionalInformation``, ``max_records_for_artifact``,
    ``return_value_mode`` and ``enduser_sql_execution``.
    """

    def set_config(self, config, plugin_config) -> None:
        """Store the tool configuration and build the LLM handle and SQL manager.

        :param config: tool configuration dict.
        :param plugin_config: plugin-level configuration (not used here).
        :raises KeyError: if ``llmId`` is missing from ``config``.
        :raises ValueError: if no datasets or no connection are configured.
        """
        self.config = config
        # Number of generate -> execute rounds before giving up.
        self.max_query_attempts = 3
        # Forwarded to to_select_query(); presumably a cap on returned rows.
        self.hard_sql_limit = self.config.get("hard_sql_limit", 200)
        self.llm: DSSLLM = dataiku.api_client().get_default_project().get_llm(self.config["llmId"])

        datasets = self.config.get("datasets", [])
        if not datasets:
            raise ValueError("No datasets provided for SQL query tool.")

        # Only validated here; it is not passed on explicitly (SQLManager
        # receives the whole config).
        connection = self.config.get("connection", None)
        if not connection:
            raise ValueError("No connection provided for SQL query tool.")

        self.sql_manager = SQLManager(datasets, self.config)

    # OpenAI caps tool descriptions (the "description" of a tool/function
    # schema in the Chat Completions API) at 1024 characters. Nothing below
    # enforces that cap, so a long dataset summary may be rejected by
    # OpenAI-backed LLMs.
    def get_descriptor(self, tool) -> Dict[str, Any]:
        """Return the descriptor (description + JSON input schema) shown to the agent.

        The description embeds the datasets/columns summary produced by
        ``SQLManager.format_datasets_description_for_descriptor``.
        """
        all_descriptions = self.sql_manager.format_datasets_description_for_descriptor()

        return {
            "description": """Given a text query explaining the question to respond, this tool will:

1. Determine whether the datasets are likely to be able to respond the question
2. Generate a SQL query 
3. Execute the SQL query
4. Return the results

The question must be responded to with a single SQL query.

If the tool determines that the question cannot be responded to, it will return an error indicating so.

Here are the datasets and columns that the tool has access to:

%s""" % all_descriptions, # Cannot use f-string here as it is not supported in the tool descriptor
            "inputSchema" : {
                # NOTE(review): this $id looks copied from a scrape tool; kept as-is.
                "$id": "https://dataiku.com/agents/tools/scrape/input",
                "title": "Input for the SQL querying tool",
                "type": "object",
                "properties" : {
                    "question" : {
                        "type": "string",
                        "description": "The question to respond"
                    }
                },
                "required": ["question"]
            }
        }

    def is_query_required(self, context: RunContext) -> Dict[str, Any]:
        """Ask the LLM whether the configured datasets can answer ``context.question``.

        Records an INFO source item with the decision.

        :param context: run context holding the question, trace and sources.
        :returns: the JSON object parsed from the LLM answer. A falsy
            ``tables_and_columns`` means the tool declines to answer;
            ``justification`` is read by callers when present.
        :raises Exception: if the LLM call fails, returns no text, or its text
            cannot be parsed. An ERROR source item is recorded first.
        """
        with context.trace.subspan("Deciding whether SQL tool can answer the question") as subspan:
            decision_completion = self.llm.new_completion()
            decision_completion.with_message(sql_tool_prompts.DECISION_TO_USE_SYSTEM_PROMPT, "system")
            decision_completion.with_message(
                sql_tool_prompts.DECISION_TO_USE_USER_PROMPT_FMT.format(
                    todays_datetime=datetime.now().strftime(PROMPT_DATE_FORMAT),
                    datasets_descriptions=self.sql_manager.format_datasets_description_for_decision(),
                ),
                "user",
            )

            try:
                decision_completion.with_message(message=context.question, role="user")
                decision_response = decision_completion.execute()
                subspan.append_trace(decision_response.trace)

                if not decision_response.text:
                    error_detail = decision_response.errorMessage or "Response did not contain text"
                    raise Exception(f"Decision response error: {error_detail}")

                logging.info("Decision result: %s", decision_response.text)
                decision_result: Dict[str, Any] = extract_json(decision_response.text)

                if not decision_result.get("tables_and_columns"):
                    subspan.outputs["decision"] = "no"
                    context.sources["items"].append({"type": "INFO", "textSnippet": "Declined to answer"})
                else:
                    subspan.outputs["decision"] = "yes"
                    context.sources["items"].append(
                        {"type": "INFO", "textSnippet": f"Decided to use {decision_result['tables_and_columns']}"}
                    )

                return decision_result
            except Exception as e:
                msg = f"Error while trying to decide whether to use SQL tool: {e}"
                logging.exception(msg)
                context.sources["items"].append({"type": "ERROR", "textSnippet": msg})
                raise Exception(msg) from e

    @staticmethod
    def format_previous_sql_errors(
        previous_sql_errors: str, json_query: Dict[str, Any], sql_query: str, error: str
    ) -> str:
        """Append one formatted failure (``QUERY_ERROR_FORMAT``) to the accumulated error text.

        :returns: ``previous_sql_errors`` with the new entry appended.
        """
        return previous_sql_errors + sql_tool_prompts.QUERY_ERROR_FORMAT.format(
            json_query=json_query, sql_query=sql_query, error=error
        )

    def get_query_completion(self, initial_user_prompt: str) -> DSSLLMCompletionQuery:
        """Build the first-attempt query-building completion (system + user prompt)."""
        completion: DSSLLMCompletionQuery = self.llm.new_completion()
        completion.with_message(sql_tool_prompts.QUERY_BUILDING_SYSTEM_PROMPT, "system")
        completion.with_message(initial_user_prompt, "user")
        return completion

    def get_query_fix_completion(self, initial_user_prompt: str, previous_sql_errors: str) -> DSSLLMCompletionQuery:
        """Build a repair completion embedding the original prompt and the errors so far."""
        completion: DSSLLMCompletionQuery = self.llm.new_completion()
        completion.with_message(sql_tool_prompts.QUERY_REPAIR_SYSTEM_PROMPT, role="system")
        completion.with_message(
            sql_tool_prompts.QUERY_REPAIR_USER_PROMPT.format(
                previous_user_prompt=initial_user_prompt, formatted_errors=previous_sql_errors
            ),
            role="user",
        )
        return completion

    def get_initial_prompt(self, decision_result: Dict[str, Any]) -> str:
        """Render the query-building user prompt from the decision step's result.

        :param decision_result: output of ``is_query_required``; must contain a
            ``tables_and_columns`` entry. A missing ``justification`` falls back
            to the same default ``invoke`` uses.
        """
        tables_and_columns = decision_result["tables_and_columns"]
        chosen_tables_and_columns = self.sql_manager.format_datasets_description_for_generation(tables_and_columns)
        return str(sql_tool_prompts.QUERY_BUILDING_USER_PROMPT_FMT.format(
            todays_datetime=datetime.now().strftime(PROMPT_DATE_FORMAT),
            chosen_tables_and_columns=chosen_tables_and_columns,
            decision_to_use_justification=decision_result.get("justification", "No justification provided"),
        ))

    def generate_query(
        self, context: RunContext, decision_result: Dict[str, Any], previous_sql_errors: str = ""
    ) -> Tuple[Optional[Dict[str, Any]], str]:
        """Ask the LLM for a JSON query description, repairing earlier failures if any.

        :param previous_sql_errors: accumulated error text; when non-empty the
            repair prompt is used instead of the first-attempt prompt.
        :returns: ``(query_building_result, previous_sql_errors)``. On failure
            the first element is None and the error has been appended to the
            returned error text (and recorded as an ERROR source item).
        """
        completion: DSSLLMCompletionQuery
        text = ""
        subspan_identifier = "Retrying SQL query generation" if previous_sql_errors else "Generating SQL query "
        with context.trace.subspan(subspan_identifier) as subspan:
            initial_user_prompt = self.get_initial_prompt(decision_result)
            if previous_sql_errors:
                completion = self.get_query_fix_completion(initial_user_prompt, previous_sql_errors)
            else:
                completion = self.get_query_completion(initial_user_prompt)
            if "additionalInformation" in self.config:  # TODO: the name additionalInformation is too vague
                completion.with_message(self.config["additionalInformation"], "user")
            completion.with_message(context.question, role="user")
            try:
                resp: DSSLLMCompletionResponse = completion.execute()
                subspan.append_trace(resp.trace)
                if not resp.text:
                    error_detail = resp.errorMessage or "Response did not contain text"
                    raise Exception(f"Generation response error: {error_detail}")
                text = str(resp.text)
                logging.info("Query building result: %s", text)
                query_building_result: Dict[str, Any] = extract_json(text)

                return query_building_result, previous_sql_errors
            except Exception as e:
                msg = f"Error while trying to generate SQL query: {e}. Response: {text}"
                logging.exception(msg)
                previous_sql_errors = self.format_previous_sql_errors(previous_sql_errors, {}, "", msg)
                context.sources["items"].append({"type": "ERROR", "textSnippet": msg})

                return None, previous_sql_errors

    def _records_df_as_artifact_part(self, df) -> Dict[str, Any]:
        """Build a RECORDS artifact part (columns + row values) from a result DataFrame.

        ``max_records_for_artifact`` (config) caps how many rows are copied;
        an unset, empty or negative value means "no cap".
        """
        raw_limit = self.config.get("max_records_for_artifact", -1)
        max_records_for_artifact = -1 if raw_limit in (None, "") else int(raw_limit)
        logging.debug("max_records_for_artifact is set to: %s", max_records_for_artifact)

        # Any negative limit means unlimited: slicing with a negative bound
        # would silently drop rows from the end instead.
        limited_df = df if max_records_for_artifact < 0 else df.iloc[:max_records_for_artifact]

        return {
            "type": "RECORDS",
            "records": {
                "columns": df.columns.to_list(),
                "data": limited_df.values.tolist(),
            }
        }

    def run_sql_query(
        self,
        context: RunContext,
        query_build_result: Dict[str, Any],
        previous_sql_errors: str
    ) -> Tuple[str, Optional[List[Dict[str, Any]]], Optional[Dict[str, Any]], str]:
        """Convert the LLM's JSON query to SQL, execute it and collect the records.

        :returns: ``(sql_query, records, records_source_item, previous_sql_errors)``.
            On a recoverable failure ``records`` and ``records_source_item`` are
            None and the error is appended to ``previous_sql_errors``.
        :raises Exception: when the error looks like a dataset/connection access
            error; retrying cannot fix that.
        """
        with context.trace.subspan("Executing SQL query") as subspan:
            sql_query: str = ""
            try:
                # Every failure below is reported back so the next attempt can
                # repair the query:
                # 1. the LLM's JSON cannot be turned into a SELECT AST,
                # 2. the AST cannot be rendered for the connection's dialect,
                # 3. the database fails to execute the SQL.
                select_query = to_select_query(query_build_result, hard_sql_limit=self.hard_sql_limit)
                logging.debug("Initial AST: %s", select_query)
                # Return value unused: presumably rewrites table references in place.
                replace_tables_in_select_query(select_query, self.config["datasets"])
                logging.debug("AST with tables replaced: %s", select_query)
                sql_query = str(toSQL(select_query, dialect=self.sql_manager.dialect))
                logging.debug("SQL query: %s", sql_query)
                query_source_item = {"type": "GENERATED_SQL_QUERY", "performedQuery": sql_query}

                df = self.sql_manager.executor.query_to_df(query=sql_query)
                df.fillna("", inplace=True)
                records: List[Dict[str, Any]] = df.to_dict("records")

                records_source_item = self._records_df_as_artifact_part(df)
                subspan.attributes["sqlQuery"] = sql_query
                context.sources["items"].extend([query_source_item, records_source_item])
                return sql_query, records, records_source_item, previous_sql_errors
            except Exception as e:
                error_message = str(e)
                if is_dataset_access_error(error_message):
                    raise Exception(
                        f"You do not have access to the connection or datasets required to answer this question. ({error_message})"
                    ) from e

                logging.exception("Exception when attempting to execute SQL query: %s", e)
                previous_sql_errors = self.format_previous_sql_errors(
                    previous_sql_errors, query_build_result, sql_query, error_message
                )
                context.sources["items"].append(
                    {"type": "ERROR", "textSnippet": f"Error executing SQL query: {e}"}
                )
                return sql_query, None, None, previous_sql_errors

    def synthesize_final_answer(
        self,
        context: RunContext,
        decision_result: Dict[str, Any],
        sql_query: str,
        records: List[Dict[str, Any]],
        records_source_item: Optional[Dict[str, Any]]
    ) -> Tuple[str, Optional[List[Any]]]:
        """Produce the tool's text output and artifacts from the query results.

        Behaviour depends on ``return_value_mode`` (default ``"AUTO"``):

        * ``ARTIFACT_ONLY``: no LLM call; a fixed message plus the records artifact.
        * ``AUTO``: the LLM returns JSON with ``text_answer`` / ``return_records``;
          if both are empty, falls back to the ARTIFACT_ONLY output.
        * ``BOTH``: LLM text answer plus the records artifact.
        * any other value: LLM text answer only.

        :returns: ``(text, artifacts)``.
        :raises Exception: if the LLM call fails, returns no text, or (AUTO) its
            JSON cannot be parsed. An ERROR source item is recorded first.
        """
        return_value_mode = self.config.get("return_value_mode", "AUTO")
        artifact = {
            "name": "Records",
            "description": "Records generated by the query",
            "parts": [records_source_item]
        }
        artifact_only_text = (
            "The response to the question has been directly provided to the user separately. %d records were returned"
            % len(records)
        )

        if return_value_mode == "ARTIFACT_ONLY":
            return (artifact_only_text, [artifact])

        with context.trace.subspan("Synthesizing final answer") as subspan:
            logging.debug("Synthesizing final answer")
            final_completion = self.llm.new_completion()
            if return_value_mode == "AUTO":
                final_completion.with_message(sql_tool_prompts.FINAL_ANSWER_AUTO_SYSTEM_PROMPT, "system")
            else:
                final_completion.with_message(sql_tool_prompts.FINAL_ANSWER_SYSTEM_PROMPT, "system")

            if "additionalInformation" in self.config:  # TODO: this name is too vague
                final_completion.with_message(self.config["additionalInformation"], "user")
            user_prompt = sql_tool_prompts.FINAL_ANSWER_USER_PROMPT_FMT.format(
                user_question=context.question,
                chosen_tables_and_columns=decision_result["tables_and_columns"],
                sql_query=sql_query,
                records=records,
            )
            final_completion.with_message(user_prompt, "user")

            try:
                final_resp = final_completion.execute()
                subspan.append_trace(final_resp.trace)
                # Without this, a text-less response would be returned to the
                # user as "" or the literal string "None".
                if not final_resp.text:
                    error_detail = final_resp.errorMessage or "Response did not contain text"
                    raise Exception(f"Final answer response error: {error_detail}")

                if return_value_mode == "AUTO":
                    final_decision = extract_json(final_resp.text)
                    logging.info("Decision on the final query: %s", final_decision)

                    text: str = final_decision.get("text_answer", "")
                    return_records = final_decision.get("return_records", False)

                    # Bad answer, return just the artifact
                    if not text and not return_records:
                        text = artifact_only_text
                        return_records = True

                    return (text, [artifact] if return_records else [])

                text = str(final_resp.text)
                logging.info("Final response: %s", text)
                return (text, [artifact] if return_value_mode == "BOTH" else [])
            except Exception as e:
                msg = f"Error while trying to synthesize final answer: {e}"
                logging.exception(msg)
                context.sources["items"].append({"type": "ERROR", "textSnippet": msg})
                raise Exception(msg) from e

    def _uses_caller_ticket(self, tool_input: Dict[str, Any]) -> bool:
        """Decide whether SQL must run impersonating the caller's ticket.

        Controlled by ``enduser_sql_execution``: ``"enduser"`` requires a
        ``dkuCallerTicket`` in the input context; ``"enduser_available"`` uses
        it only when present; any other value runs without impersonation.

        :raises Exception: in ``"enduser"`` mode when no ticket is provided.
        """
        enduser_sql_execution = self.config.get("enduser_sql_execution", None)
        invocation_context = tool_input.get("context") or {}
        has_ticket = "dkuCallerTicket" in invocation_context

        if enduser_sql_execution == "enduser":
            if not has_ticket:
                raise Exception("enduser_sql_execution set to 'enduser' but no 'dkuCallerTicket' found in input context")
            return True
        return enduser_sql_execution == "enduser_available" and has_ticket

    def _run_sql_query_for_caller(
        self,
        context: RunContext,
        tool_input: Dict[str, Any],
        use_caller_ticket: bool,
        query_build_result: Dict[str, Any],
        previous_sql_errors: str,
    ) -> Tuple[str, Optional[List[Dict[str, Any]]], Optional[Dict[str, Any]], str]:
        """Run ``run_sql_query``, inside a ticket impersonation context when requested."""
        if use_caller_ticket:
            from dataiku.core.intercom import TicketImpersonationContext
            logging.info("SQL Execution: TicketImpersonation used")
            with TicketImpersonationContext(tool_input["context"]["dkuCallerTicket"]):
                return self.run_sql_query(context, query_build_result, previous_sql_errors)

        logging.info("SQL Execution: regular")
        return self.run_sql_query(context, query_build_result, previous_sql_errors)

    def invoke(self, input: Dict[str, Any], trace) -> Dict[str, Any]:
        """Answer ``input["input"]["question"]`` with a generated SQL query.

        :param input: tool invocation payload; ``input["input"]["question"]`` is
            required and ``input["context"]`` may carry a ``dkuCallerTicket``.
        :param trace: trace object wrapped into the ``RunContext``.
        :returns: ``{"output", "sources"}`` when the tool declines, otherwise
            ``{"output", "sources", "artifacts"}``.
        :raises ValueError: if the question is missing or empty.
        :raises Exception: if no attempt produced an executable query, or on
            decision/synthesis/access errors.
        """
        self.sql_manager.collect_sample_values_if_needed()

        if not (question := input.get("input", {}).get("question")):
            raise ValueError("Input must contain a 'question' field and the value must not be empty.")

        context = RunContext(question, trace)
        decision_result = self.is_query_required(context)

        # If `tables_and_columns` value is `None`, this means that the tool decided
        # that no query is needed or it can't answer the user question.
        if not decision_result.get("tables_and_columns"):
            justification = decision_result.get("justification", "No justification provided")

            return {
                "output": TOOL_DECLINED_TO_ANSWER.format(justification=justification),
                "sources": [context.sources]
            }

        # Resolved once, before any LLM call is spent on query generation.
        use_caller_ticket = self._uses_caller_ticket(input)

        previous_sql_errors = ""
        sql_query = ""
        records: Optional[List[Dict[str, Any]]] = None
        records_source_item: Optional[Dict[str, Any]] = None

        for attempt in range(1, self.max_query_attempts + 1):
            logging.info("Attempt %d to generate SQL query", attempt)
            query_build_result, previous_sql_errors = self.generate_query(context, decision_result, previous_sql_errors)

            if query_build_result is None:
                logging.error(
                    "Failed to generate a valid SQL query on attempt %d. Previous errors: %s",
                    attempt, previous_sql_errors,
                )
                continue

            sql_query, records, records_source_item, previous_sql_errors = self._run_sql_query_for_caller(
                context, input, use_caller_ticket, query_build_result, previous_sql_errors
            )
            if records is not None:
                break

        # Checked after the loop so that a failure at generation time on the
        # final attempt is reported too (it `continue`s past any in-loop check).
        if records is None:
            raise Exception(f"Failed to generate a valid SQL query after {self.max_query_attempts} attempts.")

        final_response, artifacts = self.synthesize_final_answer(
            context, decision_result, sql_query, records, records_source_item
        )

        return {
            "output": final_response,
            "sources": [context.sources],
            "artifacts": artifacts
        }
