import dataiku
import pandas as pd
import os
import re
import sqlite3
import math
import sqlparse

from langchain.prompts.prompt import PromptTemplate
from langchain.chains.llm import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers import CommaSeparatedListOutputParser

# Prompt that turns the raw result of the SQL query into a natural-language
# answer. Placeholders: {question}, {sql_result}, {query}.
# The label used in the real input ("SQL result:") must match the label that
# the instructions and the worked example announce, otherwise the model is
# told to expect a field it never receives.
# The prompt ends with "Thought:" so that the completion starts with the
# reasoning and then emits the "Answer: " line parsed by get_answer.
DATAFRAME_TO_ANSWER_TEMPLATE = """
Your task is to answer a question based on tabular data. The data is the result of an SQL query. It will come in the following way

question: the question you need to answer
SQL result: The result of the query
query: The query from which the result comes

Don't fabricate information or rely on memory.
If the dataframe has a single column, consider it the answer. Ignore duplicate rows unless specified.
If there is no data, indicate that no matches were found. Do not put something generic, say that the information is not here
The SQL result is a dataframe rendered as text using pd.to_string()

Your answer will follow a precise pattern.

Start by making thoughts about your data.

Your thoughts must be prefixed by "Thought: "

then give your answer in a precise and concise manner.

Your answer must be prefixed by "Answer: "

Example Input:
question: give the top 3 tallest buildings in New York.
SQL result: "Buildings" : "One World Trade Center", "Central Park Tower", "111 West 57th Street"
query: SELECT Buildings FROM buildin_caracteristics ORDER BY Height DESC LIMIT 3

Example Output:
Thought: The data is thorough and complete. If any information is missing, then it is just not shown
I need to answer the question: give the top 3 tallest buildings in New York.
I have a table with a column buildings and 3 entries. these are likely the three tallest buildings in New York.
It is confirmed by the SQL query.
I have all the information to answer the question.

Answer: the 3 tallest buildings in New York are : "One World Trade Center", "Central Park Tower" and "111 West 57th Street".


Input:
question: {question}
SQL result: {sql_result}
query: {query}

Output:
Thought:"""

# Prompt asking the LLM to rewrite a SQL query whose execution failed.
# Placeholders: {dialect}, {schema}, {question}, {query}, {error_message}.
# The prompt ends with "Rewritten SQL query:"; get_answer takes the raw "text"
# output of the chain built from this template and uses it as the new query.
SQL_ERROR_CORRECTION_STR = """
You are an assistant to write SQL queries to answer user questions. You are an expert at {dialect}.
You will be given an input question, a query that tries to answer the question and an error message.
The query is an attempt at answering the question but during the execution of the query, an error occured.
Your task is to rewrite a valid SQL query that answers the question.
Your answer must follow the pattern, start by your explanation, you cannot write anything after the SQL query.

Use the following format:
Question: question that needs to be answered through a SQL query
SQL query: initial SQL query that triggered an error
Error message: error message obtained when running the query
Rewritten SQL query: new valid SQL query that answers the question

Only use the following tables:
{schema}

Question: {question}
SQL query: {query}
Error message: {error_message}
Rewritten SQL query:"""

# Prompt asking the LLM which tables are needed to answer the question.
# Placeholders: {query} (the user question) and {tables} (the description of
# the candidate tables). get_relevant_tables splits the completion on commas
# to obtain the list of table names.
TABLES_SELECTION_TEMPLATE_STR = """You are an assistant to write SQL queries to answer user questions.
Given the below input question and list of potential SQL tables, output a comma separated list of the table names that may be necessary to answer this question.
Question: {query}
Table Names: {tables}
Relevant Table Names:"""

# Prompt asking the LLM to write the SQL query that answers the question.
# Placeholders: {dialect}, {max_results}, {schema}, {query} (the user question).
# get_query fills {max_results} with the string "2" and uses the raw "text"
# output of the chain as the SQL query.
SQL_GENERATION_TEMPLATE_STR = """You are an assistant to write SQL queries to answer user questions. You are an expert at {dialect}
Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {max_results} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
When testing equality with ids, do not put quotation marks around the id (eg put id = 2 and not id = '2')
Always prefer using ILIKE to = when dealing with words.
Use the following format:
Question: Question here
SQL query: SQL Query to run
Only use the following tables:
{schema}
Question: {query}
SQL query:"""


def get_relevant_tables(project, tag, question, tables_selection_chain):
    """Ask the LLM which of the tagged datasets are needed to answer the question.

    Args:
        project: object exposing ``list_datasets()``; each item is read through
            its ``"name"`` and ``"tags"`` keys.
        tag: only datasets carrying this tag are offered to the LLM.
        question: the user question.
        tables_selection_chain: callable invoked with ``inputs={"query", "tables"}``
            and returning a mapping whose ``"text"`` entry is a comma separated
            list of table names.

    Returns:
        list[str]: the table names chosen by the LLM, without quotes or
        whitespace. Empty entries (blank answer, trailing comma) are dropped.
    """
    description_parts = []
    for dataset_info in project.list_datasets():
        if tag not in dataset_info["tags"]:
            continue
        dataset = dataiku.Dataset(dataset_info["name"])
        table_name = dataset.name.split(".")[-1]
        try:
            header = f"{table_name}: {dataset.get_config()['shortDesc']}"
        except KeyError:
            # No short description configured for this dataset.
            header = table_name
        description_parts.append(f"\n{header}\nColumns:\n")
        for column in dataset.read_schema():
            try:
                description_parts.append(f"{column['name']}: {column['comment']}\n")
            except KeyError:
                # Column without a comment: list its name only.
                description_parts.append(f"{column['name']}\n")

    # The pieces already carry their own separators, so they are concatenated
    # as-is. Joining them with "\n" would put every ": " on a line of its own.
    tables = "".join(description_parts)

    raw_answer = tables_selection_chain(inputs={"query": question, "tables": tables})[
        "text"
    ]
    unquoted = raw_answer.replace("'", "").replace('"', "")
    # Remove every whitespace character (the completion often ends with a
    # newline), not just spaces, and skip empty entries.
    candidates = ("".join(name.split()) for name in unquoted.split(","))
    return [name for name in candidates if name]


def get_query(
    project, tag, question, sql_generation_chain, tables_selection_chain, dialect
):
    """Generate the SQL query needed to answer the question.

    The relevant tables are first selected with ``get_relevant_tables``; their
    description (name, optional short description, columns and optional column
    comments) is then given to the SQL generation chain.

    Args:
        project: project object forwarded to ``get_relevant_tables``.
        tag: tag forwarded to ``get_relevant_tables``.
        question: the user question.
        sql_generation_chain: callable invoked with
            ``inputs={"query", "dialect", "schema", "max_results"}`` and
            returning a mapping whose ``"text"`` entry is the SQL query.
        tables_selection_chain: chain forwarded to ``get_relevant_tables``.
        dialect: SQL dialect name inserted in the prompt.

    Returns:
        tuple: ``(query, schema, relevant_tables)`` where ``query`` is the raw
        LLM output, ``schema`` is the textual description of all the relevant
        tables (the same text the generation prompt received, so that it can be
        reused by the error-correction prompt) and ``relevant_tables`` is the
        list of selected table names.
    """
    relevant_tables = get_relevant_tables(
        project, tag, question, tables_selection_chain
    )

    description_parts = []
    for table in relevant_tables:
        dataset = dataiku.Dataset(table)
        try:
            header = f"{table}: {dataset.get_config()['shortDesc']}"
        except KeyError:
            # No short description configured for this dataset.
            header = table
        description_parts.append(f"\n{header}\nColumns:\n")
        for column in dataset.read_schema():
            try:
                description_parts.append(f"{column['name']}: {column['comment']}\n")
            except KeyError:
                # Column without a comment: list its name only.
                description_parts.append(f"{column['name']}\n")

    # The pieces already carry their own separators, so they are concatenated
    # as-is. Joining them with "\n" would put every ": " on a line of its own.
    datasets_schemas = "".join(description_parts)

    answer = sql_generation_chain(
        inputs={
            "query": question,
            "dialect": dialect,
            "schema": datasets_schemas,
            "max_results": "2",
        }
    )["text"]
    # Return the description of *all* the relevant tables. The previous code
    # returned the loop variable holding only the last table's raw column list,
    # which was also unbound when no table had been selected.
    return answer, datasets_schemas, relevant_tables


def execute_sql_query(
    query, project, tag, datasets_restrictions=None
):
    """Execute the query through the SQL connection of the tagged datasets.

    Args:
        query: SQL query as produced by the LLM (dataset names, unquoted columns).
        project: object exposing ``list_datasets()`` and ``project_key``.
        tag: only datasets carrying this tag are considered.
        datasets_restrictions: optional list of row-level restrictions forwarded
            to ``format_query``.

    Returns:
        str: the result rendered with ``DataFrame.to_string()``, or a message
        starting with ``"Error"``. The message starts with ``"Error: 'table'"``
        when the datasets cannot be queried through a SQL connection (no
        ``"connection"`` / ``"table"`` entry in their parameters); callers use
        this prefix to fall back to an in-process execution.
    """
    restrictions = [] if datasets_restrictions is None else datasets_restrictions
    try:
        datasets = [x for x in project.list_datasets() if tag in x["tags"]]
        dataset_names = [x["name"] for x in datasets]
        try:
            # The connection of the first tagged dataset is used for the query.
            connection = datasets[0]["params"]["connection"]
        except KeyError:
            return "Error: 'table'"
        try:
            query = format_query(
                project, query, dataset_names, datasets_restrictions=restrictions
            )
        except Exception as e:
            # A missing "table" parameter surfaces here as KeyError('table'),
            # which yields the "Error: 'table'" prefix expected by callers.
            return f"Error: {e} (while formatting the query: {query})"
        executor = dataiku.SQLExecutor2(connection=connection)
        # Run the query once. The former second execution through
        # query_to_iter only built a string that was never used.
        return executor.query_to_df(query).to_string()
    except Exception as e:
        return f"Error: {e} (while executing the query)"


def format_query(project, query, dataset_names, datasets_restrictions=None):
    """Format the query for it to be executable by the SQLExecutor.

    The query is split on whitespace, parentheses, commas and semicolons, and
    each token is rewritten:

    * a bare column name is wrapped in double quotes;
    * a dataset name is replaced by its quoted physical table name and, when a
      restriction targets it, by a filtering sub-query aliased ``TABLE_<i>``;
    * ``dataset.column`` becomes ``"physical_table"."column"``;
    * ``alias.column`` gets its column quoted when it is a known column;
    * anything else is kept untouched.

    Args:
        project: object exposing ``project_key``.
        query: SQL query to rewrite.
        dataset_names: names of the datasets that may appear in the query.
        datasets_restrictions: optional list of dicts with the keys
            ``"dataset_name"``, ``"key"`` and ``"value"``; rows of the dataset
            are limited to ``key = value``. The values are inserted verbatim in
            the SQL text, so they must come from a trusted source.

    Returns:
        str: the rewritten query.

    Raises:
        KeyError: if a referenced dataset has no ``"table"`` parameter. Callers
            rely on the resulting ``'table'`` message to detect datasets that
            are not backed by a SQL table.
    """
    restrictions = datasets_restrictions or []
    quoted_pattern = r"(['\"])(.*?)\1"
    delimiters = r"(\s+|\(|\)|,|;)"
    project_key = project.project_key
    dataset_names = list(dataset_names)
    known_datasets = set(dataset_names)

    # Every column of every candidate dataset; a set gives O(1) membership tests.
    columns = {
        column["name"]
        for name in dataset_names
        for column in dataiku.Dataset(name).read_schema()
    }

    def physical_table_name(dataset_name):
        # Resolve the table name stored in the dataset settings.
        table = dataiku.Dataset(dataset_name).get_config()["params"]["table"]
        return table.replace("${projectKey}", project_key)

    def is_quoted(text):
        return re.match(quoted_pattern, text) is not None

    pieces = []
    for token in re.split(delimiters, query):  # Tokens are "words" in the query
        parts = token.split(".")

        if len(parts) == 1:  # The token does not contain a "."
            if token in columns:
                # Quote the column unless it already carries quotation marks.
                pieces.append(token if is_quoted(token) else f'"{token}"')
            elif token in known_datasets:
                table = f'"{physical_table_name(token)}"'
                for index, restriction in enumerate(restrictions):
                    if token == restriction["dataset_name"]:
                        # Filter to limit access to authorized data
                        table = (
                            f'(SELECT * FROM {table} WHERE {restriction["key"]}'
                            f' = {restriction["value"]}) AS TABLE_{index}'
                        )
                pieces.append(table)
            else:
                pieces.append(token)
            continue

        qualifier, column = parts[0], parts[1]
        if qualifier in known_datasets:
            # dataset.column or dataset."column". The table name is quoted
            # exactly like in the stand-alone case so that both spellings
            # designate the same identifier on case-sensitive databases.
            column_ref = column if is_quoted(column) else f'"{column}"'
            pieces.append(f'"{physical_table_name(qualifier)}".{column_ref}')
        elif column in columns and not is_quoted(column):
            # alias.column: only the column needs quotation marks.
            pieces.append(f'{qualifier}."{column}"')
        else:  # Any other case (numbers such as 1.5, already quoted names, ...)
            pieces.append(token)

    return "".join(pieces)


def format_query_pandas(query, dataset_names, datasets_restrictions=None):
    """Format the query for it to be executable using pd.read_sql_query().

    Two rewrites are applied:

    * ``ILIKE`` is replaced by ``LIKE`` (the LIKE operator in SQLite is
      case-insensitive and ILIKE does not exist there);
    * every token equal to a restricted dataset name is replaced by a
      sub-query keeping only the authorized rows.

    All the other tokens, including dotted ones such as ``table.column`` or
    ``1.5``, are kept untouched.

    Args:
        query: SQL query to rewrite.
        dataset_names: names of the available datasets (currently unused, kept
            for signature compatibility).
        datasets_restrictions: optional list of dicts with the keys
            ``"dataset_name"``, ``"key"`` and ``"value"`` (same structure as in
            ``format_query``). The values are inserted verbatim in the SQL
            text, so they must come from a trusted source.

    Returns:
        str: the rewritten query.
    """
    restrictions = datasets_restrictions or []

    # Match ILIKE whatever its case and the whitespace around it.
    query = re.sub(r"\bILIKE\b", "LIKE", query, flags=re.IGNORECASE)

    delimiters = r"(\s+|\(|\)|,|;)"
    pieces = []
    for token in re.split(delimiters, query):
        rewritten = token
        for restriction in restrictions:
            if token == restriction["dataset_name"]:
                # Filter to limit access to authorized data. Several
                # restrictions on the same dataset are nested so that all apply.
                rewritten = (
                    f'(SELECT * FROM {rewritten} WHERE {restriction["key"]}'
                    f' = {restriction["value"]})'
                )
        pieces.append(rewritten)
    return "".join(pieces)


def _null_safe(func):
    """Wrap a math function so that it behaves like a SQL scalar function.

    NULL arguments and arguments outside the function's domain give NULL
    instead of raising, which would abort the whole query.
    """

    def wrapper(*args):
        if any(arg is None for arg in args):
            return None
        try:
            return func(*args)
        except (TypeError, ValueError, OverflowError):
            return None

    return wrapper


def create_database(table_names):
    """Create the temporary database to execute the sql query with pandas dataframes.

    The database lives in memory: nothing is written to disk and each call
    gets its own private copy of the data, which disappears when the returned
    connection is closed. ``SQRT`` and ``POWER`` are registered as SQL functions.

    Args:
        table_names: names of the datasets to load, one table per dataset.

    Returns:
        sqlite3.Connection: connection to the populated database. The caller is
        responsible for closing it.
    """
    conn = sqlite3.connect(":memory:")
    try:
        conn.create_function("SQRT", 1, _null_safe(math.sqrt))
        conn.create_function("POWER", 2, _null_safe(math.pow))
        for table in table_names:
            dataiku.Dataset(table).get_dataframe().to_sql(
                table, conn, index=False, if_exists="replace"
            )
    except Exception:
        # Do not leak the connection when a dataset cannot be loaded.
        conn.close()
        raise
    return conn

def is_valid_query(query):
    """Tell whether the query is free of data- or schema-modifying statements.

    Every statement of the query is inspected, including tokens nested in
    parentheses (sub-queries, CTEs), so that ``SELECT 1; DELETE FROM t`` is
    rejected as well. String literals and identifiers are not keywords and are
    therefore never flagged.

    Args:
        query: SQL text, possibly holding several statements.

    Returns:
        bool: ``False`` if any statement contains a DDL keyword or one of
        UPDATE, DELETE, INSERT, MERGE, UPSERT, REPLACE, TRUNCATE; ``True``
        otherwise (including for an empty query).
    """
    forbidden_keywords = {
        "UPDATE",
        "DELETE",
        "INSERT",
        "MERGE",
        "UPSERT",
        "REPLACE",
        "TRUNCATE",
    }
    for statement in sqlparse.parse(query):
        for token in statement.flatten():
            token_type = str(token.ttype)
            if token_type == "Token.Keyword.DDL":
                return False  # No schema change
            if (
                token_type.startswith("Token.Keyword")
                and str(token).upper() in forbidden_keywords
            ):
                return False  # No destructive DML Statement
    return True

def _extract_final_answer(llm_output):
    """Return the text following the "Answer:" marker, or the whole output if absent."""
    match = re.search(r"Answer:\s*(.*)", llm_output, re.DOTALL)
    if match and match.group(1).strip():
        # Keep the complete answer: cutting at the first period would truncate
        # decimal numbers ("3.14") and multi-sentence answers.
        return match.group(1).strip()
    return llm_output


def get_answer(
    project,
    tag,
    question,
    llm,
    datasets_restrictions=None,
    max_error=2,
    dialect="SQL:2011",
):
    """Provide the answer to the question based on the database. Uses Text2SQL.

    Steps: select the relevant tables and generate a SQL query, execute it
    (through the SQL connection, or with pandas/SQLite when the datasets are
    not SQL tables), let the LLM correct the query when its execution fails,
    then let the LLM phrase the final answer from the result.

    Args:
        project: object exposing ``list_datasets()`` and ``project_key``.
        tag: only datasets carrying this tag are used.
        question: the user question.
        llm: LangChain LLM used by every chain.
        datasets_restrictions: optional list of row-level restrictions, see
            ``format_query``.
        max_error: maximum number of LLM corrections of a failing query.
        dialect: SQL dialect name inserted in the prompts.

    Returns:
        tuple[str, str]: the final answer (or the error message if something
        unexpected happened) and the last SQL query produced by the LLM.
    """
    forbidden_message = (
        "Error: You do not have the right to use DELETE, UPDATE or INSERT"
    )
    restrictions = [] if datasets_restrictions is None else datasets_restrictions
    verbose = os.environ.get("DKU_LANGCHAIN_VERBOSE", None) is not None

    tables_selection_template = PromptTemplate(
        template=TABLES_SELECTION_TEMPLATE_STR,
        input_variables=["query", "tables"],
        output_parser=CommaSeparatedListOutputParser(),
    )
    sql_generation_template = PromptTemplate(
        template=SQL_GENERATION_TEMPLATE_STR,
        input_variables=["query", "schema", "dialect", "max_results"],
    )
    sql_error_correction_template = PromptTemplate(
        template=SQL_ERROR_CORRECTION_STR,
        input_variables=["schema", "question", "query", "error_message", "dialect"],
    )
    tables_selection_chain = LLMChain(
        llm=llm, prompt=tables_selection_template, verbose=verbose
    )
    sql_generation_chain = LLMChain(
        llm=llm, prompt=sql_generation_template, verbose=verbose
    )
    sql_correction_chain = LLMChain(
        llm=llm, prompt=sql_error_correction_template, verbose=verbose
    )

    query, schema, relevant_tables = get_query(
        project, tag, question, sql_generation_chain, tables_selection_chain, dialect
    )
    valid_query = is_valid_query(query)

    conn = None  # SQLite connection, only created for the pandas fallback
    try:
        use_pandas = False  # becomes True when the SQLExecutor cannot be used
        error_count = 0
        while True:
            if not valid_query:
                answer = forbidden_message
            elif use_pandas:
                try:
                    dataset_names = [
                        x["name"]
                        for x in project.list_datasets()
                        if tag in x["tags"]
                    ]
                    pandas_query = format_query_pandas(
                        query, dataset_names, restrictions
                    )
                    answer = pd.read_sql_query(pandas_query, conn).to_string()
                except Exception as e:
                    # pandas wraps the SQLite errors in its own exception type.
                    # Any failure is reported with the "Error" prefix so that
                    # the correction loop below also covers this path.
                    answer = f"Error: {e} (while executing the query)"
            else:
                answer = execute_sql_query(
                    query, project, tag, datasets_restrictions=restrictions
                )
                # The answer is "Error: 'table'" if we cannot execute a SQL query.
                # In this case, we need to use Pandas: the same query is retried
                # without counting an error, since the query is not at fault.
                if answer.startswith("Error: 'table'"):
                    use_pandas = True
                    conn = create_database(relevant_tables)
                    continue

            if not answer.startswith("Error") or error_count >= max_error:
                break

            # The execution failed: ask the LLM to correct the query.
            error_count += 1
            query = sql_correction_chain(
                inputs={
                    "schema": schema,
                    "question": question,
                    "query": query,
                    "error_message": answer,
                    "dialect": dialect,
                }
            )["text"]
            valid_query = is_valid_query(query)

        # The template is given unformatted and the values are passed as chain
        # inputs: formatting it first and parsing the result as a template
        # breaks as soon as the data or the query contains curly braces.
        llm_chain = LLMChain(
            llm=llm,
            prompt=PromptTemplate.from_template(DATAFRAME_TO_ANSWER_TEMPLATE),
        )
        data = llm_chain(
            inputs={"question": question, "sql_result": answer, "query": query}
        )["text"]
        final_answer = _extract_final_answer(data)
    except Exception as e:
        # Best effort: the error message is returned as the answer.
        final_answer = str(e)
    finally:
        if conn is not None:
            conn.close()
    return final_answer, query
