import logging
from typing import Dict, List, Optional

import dataiku
import pandas as pd
from dataiku import SQLExecutor2
from dataiku.core.dataset import Dataset
from dataiku.sql import Column, Constant, Expression, InlineSQL, SelectQuery, Table, toSQL
from dataiku.sql.translate import Dialects
from dataikuapi.dssclient import DSSClient

# Schema column types for which distinct sample values are meant to be collected.
# Only string columns are supported at the moment.
TO_UNIQUE: List[str] = ["string"]

# SQL string-aggregation function -> Dataiku database types (``databaseType``) that provide it.
# SQLManager uses this to choose between LISTAGG and STRING_AGG when aggregating distinct values.
AGG_TYPE = {
    "STRING_AGG": ["BigQuery", "PostgreSQL", "SQLServer"],
    "LISTAGG": ["Databricks", "Oracle", "Redshift", "Snowflake"],
}

class SQLManager:
    """Describe a set of Dataiku SQL datasets for LLM prompts and sample their column values.

    Every query (for any of the datasets) is run through a single ``SQLExecutor2`` bound to
    the first dataset, and the SQL dialect is read from that first dataset's location info.
    NOTE(review): the datasets are therefore expected to share one SQL connection; this is
    not checked here.
    """

    def __init__(self, datasets: List[str], config: Dict):
        """Load the settings of every dataset and prepare the SQL executor.

        :param datasets: names of datasets in the default project; must not be empty.
        :param config: plugin configuration. Keys read by this class are
            ``include_column_names_in_descriptor``, ``include_column_descriptions_in_descriptor``,
            ``sample_values_strategy`` and ``sample_values_from_data_cardinality_cutoff``.
        :raises ValueError: if ``datasets`` is empty.
        """
        if not datasets:
            # Without this check, ``datasets[0]`` below fails with an opaque IndexError.
            raise ValueError("SQLManager requires at least one dataset")
        self.config = config
        self.client: DSSClient = dataiku.api_client()
        self.datasets = datasets
        first_dataset = Dataset(self.datasets[0])
        self.executor = SQLExecutor2(dataset=first_dataset)
        # TODO: might be able to do this later
        self.dialect: str = first_dataset.get_location_info().get("info", {}).get("databaseType", "UNKNOWN")
        project = self.client.get_default_project()
        self.dataset_settings_map = {name: project.get_dataset(name).get_settings() for name in self.datasets}
        # Set by collect_sample_values_if_needed(): dataset name -> {column name -> aggregated values}
        self.sample_values_map = None


    def format_datasets_description_for_descriptor(self) -> str:
        """Describe every dataset (name, description and optionally its columns) for the descriptor prompt.

        Column names (and their comments) are included unless the config sets
        ``include_column_names_in_descriptor`` (``include_column_descriptions_in_descriptor``)
        to a falsy value. A warning is logged for each dataset that has no description.
        """
        full_description = ""

        for dataset_name in self.datasets:
            full_description += f"* Dataset name: {dataset_name}\n"
            settings = self.dataset_settings_map[dataset_name]
            description = settings.description or settings.short_description
            if description:
                full_description += f"* Dataset description: {description}\n"
            else:
                logging.warning("Dataset %s does not have a description", dataset_name)

            if self.config.get("include_column_names_in_descriptor", True):
                full_description += "* Dataset columns:\n"
                for col in settings.schema_columns:
                    full_description += "  * %s (type: %s)" % (col.get("name", ""), col.get("type", ""))

                    if self.config.get("include_column_descriptions_in_descriptor", True) and col.get("comment"):
                        full_description += ": %s" % col["comment"]
                    full_description += "\n"

            full_description += "\n\n"
        return full_description


    def format_datasets_description_for_decision(self):
        """Describe every dataset with all of its columns for the table/column decision prompt."""
        full_description = ""

        for dataset_name in self.datasets:
            full_description += self._format_single_dataset_description_for_llm(dataset_name, None)
            full_description += "\n\n"

        full_description += "__ End of Datasets Description __"
        return full_description


    def format_datasets_description_for_generation(self, tables_and_columns):
        """Describe only the tables/columns selected by the decision step, for SQL generation.

        :param tables_and_columns: iterable of dicts with keys ``table_name`` and ``columns``.
            Entries naming a dataset that is not managed by this instance are skipped with a warning.
        """
        subset_descriptions = ""

        for tc in tables_and_columns:
            dataset_name, selected_columns = tc.get("table_name", ""), tc.get("columns", [])

            if dataset_name not in self.datasets:
                logging.warning("Decision referenced non-existing dataset: %s", dataset_name)
                continue

            subset_descriptions += self._format_single_dataset_description_for_llm(dataset_name, selected_columns, for_generation=True)
            subset_descriptions += "\n\n"

        subset_descriptions += "__ End of Datasets Description __"
        return subset_descriptions


    def collect_sample_values_if_needed(self):
        """Fill ``sample_values_map`` from the data when ``sample_values_strategy`` is ``FROM_DATA``."""
        if self.config.get("sample_values_strategy", "NONE") == "FROM_DATA":
            self.sample_values_map = {}

            for dataset_name in self.datasets:
                low_cardinality_cats = self.get_low_cardinality_cats(dataset_name)
                self.sample_values_map[dataset_name] = low_cardinality_cats
                logging.info("Storing sample values for %s: %s", dataset_name, low_cardinality_cats)

    def columns_in_uppercase(self, dataset: Dataset) -> bool:
        """Return True if every column of the table backing ``dataset`` has an upper-case name.

        Only checked for Snowflake (False for any other dialect). The check runs a
        ``SELECT * ... LIMIT 1`` query; any error is logged and treated as False.
        """
        # Currently we know this is only an issue for Snowflake
        # so it doesn't make sense to check for other DB locations
        if self.dialect != Dialects.SNOWFLAKE:
            return False
        try:
            q = SelectQuery().select_from(dataset).select(Column("*")).limit(1)
            query = toSQL(q, dataset=dataset)
            df = self.executor.query_to_df(query)
            return all(c.isupper() for c in df.columns)
        except Exception as e:
            logging.exception(f"Error when trying to determine case of table columns: {e}")
            return False

    def _format_single_dataset_description_for_llm(
        self, dataset_name, included_column_names: Optional[List[str]], for_generation: bool = False
    ):
        """Describe one dataset (name, description, columns) as a bullet list for an LLM prompt.

        :param dataset_name: name of a dataset managed by this instance.
        :param included_column_names: column names to describe; ``None`` (or a list containing
            ``"*"``) means every column.
        :param for_generation: when True and the table's columns are upper-case (see
            ``columns_in_uppercase``), column names are rendered upper-cased so that generated
            SQL refers to the physical column names.
        :return: the description, one ``\\n``-terminated line per item.
        """
        settings = self.dataset_settings_map[dataset_name]
        samples = self.sample_values_map.get(dataset_name) if self.sample_values_map is not None else None

        # columns_in_uppercase() may run a query, so only ask when the answer is used: to render
        # names for generation, or to match sample-value keys. get_distinct_df() stores those
        # keys upper-cased for upper-case tables, whatever prompt they are later used in.
        is_uppercase = False
        if for_generation or samples:
            is_uppercase = self.columns_in_uppercase(Dataset(dataset_name))

        lines = [f"* Dataset name: {dataset_name}"]
        description = settings.description or settings.short_description
        if description:
            lines.append(f"* Dataset description: {description}")
        lines.append("* Columns:")

        include_all = included_column_names is None or "*" in included_column_names
        for col in settings.schema_columns:
            schema_name = col.get("name", "")
            if not schema_name:
                continue
            if not include_all and schema_name not in included_column_names:
                continue
            physical_name = schema_name.upper() if is_uppercase else schema_name
            display_name = physical_name if for_generation else schema_name

            line = "  * %s (type: %s)" % (display_name, col.get("type", ""))
            # Skip missing/None comments instead of emitting ": " or ": None" into the prompt.
            comment = col.get("comment")
            if comment:
                line += ": %s" % comment
            if samples:
                sample_values = samples.get(physical_name)
                if sample_values:
                    line += ". Distinct sample values: %s" % sample_values
            lines.append(line)

        return "\n".join(lines) + "\n"

    def get_distinct_df(self, dataset_name) -> Optional[pd.DataFrame]:
        """Query the aggregated distinct values of each eligible column of a dataset.

        A query is built for each named column whose type is listed in ``TO_UNIQUE``. The
        queries are combined with ``UNION ALL``. Each one returns a single row
        (``COL_NAME``, ``UNIQUE`` = the distinct values joined with ``', '``, ``COUNT``).
        The row is returned only when the column has fewer distinct rows than the
        ``sample_values_from_data_cardinality_cutoff`` config value (default 30).

        :param dataset_name: name of a dataset managed by this instance.
        :return: the query result, or None when the dataset has no eligible column.
        :raises ValueError: if the dataset has no location info or table name, or if the
            dialect has no known string-aggregation function.
        """
        # Find the table for the dataset
        dataset_info = dataiku.Dataset(dataset_name).get_location_info().get("info")
        if not dataset_info:
            raise ValueError("No 'info' for dataiku dataset")
        table_name = dataset_info.get("table")
        if not table_name:
            raise ValueError("No table name found in dataiku dataset location info")
        ast_table = Table(name=table_name, catalog=dataset_info.get("catalog"), schema=dataset_info.get("schema"))

        # Filter columns for which we'll get distinct values. Nameless columns are skipped:
        # they would produce an invalid column reference.
        settings = self.dataset_settings_map[dataset_name]
        cols_to_use = [col for col in settings.schema_columns if col.get("name") and col.get("type") in TO_UNIQUE]
        if not cols_to_use:
            # An empty UNION ALL is not a valid statement, so there is nothing to run.
            logging.info("No column eligible for sample values in dataset %s", dataset_name)
            return None

        # VALUES is a protected term in some dialects so should be avoided
        # Some dialects prefer alias is upper case and some as lower
        value_alias = "VALS" if self.dialect in ["Snowflake", "Oracle"] else "vals"
        # The aggregate is identical for every column: resolve it (and reject unsupported
        # dialects) once, before doing any per-column work.
        if self.dialect in AGG_TYPE["LISTAGG"]:
            agg_sql = f"LISTAGG({value_alias}, ', ') WITHIN GROUP (ORDER BY {value_alias})"
        elif self.dialect in AGG_TYPE["STRING_AGG"]:
            agg_sql = f"STRING_AGG({value_alias}, ', ')"
        else:
            raise ValueError(f"Unknown or unsupported dialect selected: {self.dialect}")

        is_uppercase = self.columns_in_uppercase(Dataset(dataset_name))
        cardinality_cutoff = int(self.config.get("sample_values_from_data_cardinality_cutoff", "30"))
        logging.debug(
            "Getting low cardinality categories for dataset %s with cutoff %d", dataset_name, cardinality_cutoff
        )

        column_queries = []
        for col in cols_to_use:
            col_name = col["name"].upper() if is_uppercase else col["name"]

            # Inner query: the distinct values of the column
            sub_query = SelectQuery()
            sub_query.distinct()
            sub_query.select(Column(col_name), alias=value_alias)
            sub_query.select_from(ast_table)

            # Outer query: a single aggregated row, kept only below the cardinality cutoff
            select_query = SelectQuery()
            select_query.select(Constant(col_name), alias="COL_NAME")
            select_query.select(InlineSQL(agg_sql), alias="UNIQUE")
            select_query.select(Expression(Column(value_alias)).count(), alias="COUNT")
            select_query.select_from(sub_query, alias="sub")
            select_query.having(Expression(Expression(Column("*")).count()).lt(Constant(cardinality_cutoff)))
            column_queries.append(toSQL(select_query, dialect=self.dialect))

        query = "\nUNION ALL\n".join(column_queries)
        logging.debug(query)
        df: pd.DataFrame = self.executor.query_to_df(query)
        return df


    def get_low_cardinality_cats(self, dataset_name) -> Dict[str, str]:
        """Returns a dict column name -> comma-separated list of values

        Only columns under the cardinality cutoff (see ``get_distinct_df``) appear. An empty
        dict is returned when there is nothing to sample.

        :raises ValueError: if the query result lacks the ``UNIQUE`` or ``COL_NAME`` column.
        """
        df: Optional[pd.DataFrame] = self.get_distinct_df(dataset_name)
        if df is None or df.empty:
            return {}
        for required in ("UNIQUE", "COL_NAME"):
            if required not in df.columns:
                raise ValueError("DataFrame must contain a '%s' column." % required)

        return df.set_index("COL_NAME")["UNIQUE"].to_dict()
