import dataiku
import sqlparse
import re

from dataiku import SQLExecutor2
from langchain_core.tools import ToolException
from pandasql import sqldf

# Statement keywords that modify data, schema or permissions. If any of them
# appears as a SQL keyword anywhere in the submitted text, the query is rejected.
_FORBIDDEN_SQL_OPERATIONS = frozenset({
    "INSERT", "UPDATE", "DELETE", "MERGE", "UPSERT",  # data modification
    "CREATE", "ALTER", "DROP", "TRUNCATE",            # schema changes
    "GRANT", "REVOKE",                                # permission changes
})


def _standalone_name_pattern(name):
    """Compile a regex matching ``name`` literally when not adjacent to a word character."""
    # re.escape keeps regex metacharacters in names (".", "(", "$", ...) literal.
    # The lookarounds behave exactly like \b for ordinary word-character names,
    # but also work for names that begin or end with a non-word character.
    return re.compile(rf"(?<!\w){re.escape(name)}(?!\w)")


def _find_forbidden_operation(statements):
    """Return the first forbidden keyword (upper-cased) found in ``statements``, or None."""
    for statement in statements:
        # flatten() yields every leaf token, so keywords nested inside CTEs,
        # parentheses and sub-queries are inspected too, as are statements after
        # the first one. An unquoted identifier spelled like one of these
        # keywords is tokenized as a keyword and is therefore rejected as well.
        for token in statement.flatten():
            if token.is_keyword and token.normalized in _FORBIDDEN_SQL_OPERATIONS:
                return token.normalized
    return None


def _quote_column(query, col_name):
    """Return ``query`` with every standalone ``col_name`` wrapped in double quotes."""
    if not col_name:
        return query
    quoted = f'"{col_name}"'
    # Un-quote first so a column the query already quoted is not quoted twice.
    query = query.replace(quoted, col_name)
    # A replacement function (rather than a template string) stops re.sub from
    # interpreting backslashes or group references inside the column name.
    return _standalone_name_pattern(col_name).sub(lambda _match: quoted, query)


def _replace_table_name(query, table_name, replacement):
    """Return ``query`` with references to ``table_name`` replaced by ``replacement``."""
    # NOTE(review): only the text before the first "." is matched, so for a
    # qualified "PROJECTKEY.dataset" name this targets the project key -- confirm intent.
    short_name = table_name.split('.')[0]
    if not short_name:
        return query
    return _standalone_name_pattern(short_name).sub(lambda _match: replacement, query)


def execute_sql(table_names: list, sql_query: str) -> dict:
    """
    Returns the result of SQL query execution.
    When working with "bool" columns, you filter using the TRUE or FALSE values (in uppercase & NO quotation marks).
    """
    # The docstring above is kept verbatim on purpose: it reads as instructions
    # to the LLM agent (presumably used as the tool description -- confirm).
    #
    # Args:
    #     table_names: Dataiku dataset names referenced by ``sql_query``.
    #     sql_query: a read-only (SELECT) SQL query; only its first statement runs.
    #
    # Returns a dict with:
    #     'result': ``str`` of the result rows as a list of record dicts or, when a
    #         forbidden operation is detected, a list with a single {'error': ...} dict;
    #     'formatted_query': the query actually executed (columns double-quoted,
    #         table names rewritten for the target engine);
    #     'original_query': ``sql_query`` unchanged.
    #
    # Raises:
    #     ToolException wrapping any error raised while parsing, reading the
    #     datasets or running the query.
    try:
        if not sql_query or not sql_query.strip():
            raise ValueError('The SQL query is empty.')

        statements = sqlparse.parse(sql_query)
        # Only the first statement is executed; every statement is still checked below.
        formatted_query = statements[0].value

        forbidden = _find_forbidden_operation(statements)
        if forbidden is not None:
            return {
                'result': [{'error': f'The generated SQL Query has an unexpected SQL operation - {forbidden}. Only "SELECT" queries are allowed'}],
                'formatted_query': formatted_query,
                'original_query': sql_query,
            }

        is_sql_table = True  # overwritten by each supported dataset: the last one picks the engine
        database = None      # Dataiku connection name taken from the SQL dataset config
        dataframes = {}      # sqldf environment: placeholder table name -> DataFrame

        for index, table_name in enumerate(table_names):
            dataset = dataiku.Dataset(table_name)
            dataset.read_schema()

            # Double-quote every known column so the target engine treats it as an identifier.
            for col in dataset.cols:
                formatted_query = _quote_column(formatted_query, col['name'])

            config = dataset.get_config()
            dataset_type = config['type']

            if dataset_type in ['UploadedFiles']:  # TODO: Add other connection types & handle for big data
                # Flat files are loaded into memory and queried with sqldf under the name dfN.
                placeholder = f'df{index}'
                dataframes[placeholder] = dataset.get_dataframe()
                formatted_query = _replace_table_name(formatted_query, table_name, placeholder)
                is_sql_table = False

            if dataset_type in ['PostgreSQL', 'Snowflake']:  # TODO: Add other connection types
                # SQL datasets are queried in-database under their physical table name.
                params = config['params']
                database = params['connection']
                sql_table_name = params['table'].replace('${projectKey}', config['projectKey'])
                formatted_query = _replace_table_name(formatted_query, table_name, sql_table_name)
                is_sql_table = True

        if is_sql_table:
            if database is None:
                # No dataset of a supported type was seen, so there is no connection to query.
                raise ValueError(
                    f'None of the tables {table_names} has a supported dataset type '
                    '(UploadedFiles, PostgreSQL, Snowflake).'
                )
            result = SQLExecutor2(connection=database).query_to_df(formatted_query)
        else:
            # The DataFrames are passed explicitly instead of being published as module globals.
            result = sqldf(formatted_query, dataframes)

        return {
            'result': str(result.to_dict('records')),
            'formatted_query': formatted_query,
            'original_query': sql_query,
        }

    except Exception as e:
        raise ToolException(f"""
'error' : {str(e)},
'error_type' : {type(e)}
""") from e

def get_table_columns(table_names: list) -> str:
    """
    Returns list of table column names and types in JSON.
    You ALWAYS obtain contextual information about the columns in the table first before writing a SQL query.
    """
    # The docstring above is kept verbatim on purpose: it reads as instructions
    # to the LLM agent (presumably used as the tool description -- confirm).
    # NOTE(review): despite "in JSON", the output is plain text: for each table,
    # the pandas dtypes and up to three sample rows.
    #
    # Args:
    #     table_names: Dataiku dataset names to describe.
    #
    # Returns:
    #     The per-table descriptions concatenated in order ('' when no tables are given).
    #
    # Raises:
    #     ToolException wrapping any error raised while reading a dataset.
    try:
        sections = []

        for table_name in table_names:
            # Only a 3-row sample is read; the dtypes reported are those of this sample DataFrame.
            df = dataiku.Dataset(table_name).get_dataframe(limit=3)

            sections.append(f'''
    The tabular data in your file named {table_name} has the following schema:

    {str(df.dtypes)}

    Here are some example rows from the table:

    {df.head(3).to_string()}
    ''')

        return ''.join(sections)

    except Exception as e:
        raise ToolException(f"""
'error' : {str(e)},
'error_type' : {type(e)}
""") from e