import logging

import dataiku
from dataiku.base.spark_like import SparkLike
from dataiku.base.sql_dialect import SparkLikeDialect
from dataiku.core import snowflake_utils
import base64
from urllib.parse import urlparse, parse_qsl

try:
    from snowflake.snowpark import Session
    from snowflake.snowpark.functions import col, lit, datediff, call_builtin
except ImportError as e:
    raise Exception("Unable to import Snowpark libraries. Make sure you are using a code-env where Snowpark is installed. Cause: " + str(e))

class DkuSnowparkDialect(SparkLikeDialect):
    """SQL dialect for Snowpark: type-name mapping, identifier quoting and literal building."""

    # Snowflake function applied to str(value) when building a literal for a
    # DSS column of the given type. Column types not listed here go through
    # lit(value) unchanged.
    _LITERAL_CONVERSION_FUNCTIONS = {
        'date': 'TO_TIMESTAMP_TZ',
        'dateonly': 'TO_DATE',
        'datetimenotz': 'TO_TIMESTAMP_NTZ',
    }

    def __init__(self):
        SparkLikeDialect.__init__(self)

    def _get_to_dss_types_map(self):
        """Return the mapping from Snowpark datatype names to DSS column types.

        The map is built on first use and cached on the instance. Its keys use
        the same names that _get_datatype_name_from_df_datatype() returns
        (e.g. 'TimestampTz', 'TimestampNtz').
        """
        # NOTE(review): assumes the base class initialises self._to_dss_types_map to None
        if self._to_dss_types_map is None:
            self._to_dss_types_map = {
                'Binary': 'string',
                'Boolean': 'boolean',
                'Date': 'dateonly',
                'String': 'string',
                'Timestamp': 'date',  # By default we consider it's a TIMESTAMP_TZ
                'TimestampTz': 'date',
                'TimestampNtz': 'datetimenotz',
                'Time': 'string',
                'Byte': 'tinyint',
                'Short': 'smallint',
                'Integer': 'int',
                'Long': 'bigint',
                'Float': 'float',
                'Double': 'double',
                'Decimal': 'double',
                'Geography': 'geometry',
                'Array': 'string',
                'Map': 'string',
                'Variant': 'string',
                'ColumnIdentifier': 'string',
                'Null': 'string',
                'Struct': 'string',
            }
        return self._to_dss_types_map

    def allow_empty_schema_after_catalog(self):
        """Whether specifying a table as (catalog, table) is possible"""
        return True

    def identifier_quote_char(self):
        """Get the character used to quote identifiers"""
        return '"'

    def _column_name_to_sql_column(self, identifier):
        """Return a Snowpark column referencing the quoted form of ``identifier``."""
        # quote_identifier() comes from the base dialect (presumably wraps with identifier_quote_char())
        return col(self.quote_identifier(identifier))

    def _python_literal_to_sql_literal(self, value, column_type, original_type=None):
        """Build a Snowpark literal column for ``value`` stored in a DSS column of ``column_type``.

        For the DSS date types ('date', 'dateonly', 'datetimenotz'), the string
        form of ``value`` is passed to the matching Snowflake conversion function.
        Any other value is wrapped with lit(). ``original_type`` is not used here.
        """
        builtin_name = self._LITERAL_CONVERSION_FUNCTIONS.get(column_type)
        if builtin_name is None:
            return lit(value)
        return call_builtin(builtin_name, str(value))

    def _get_components_from_df_schema(self, df_schema):
        """Split a Snowpark schema into ``(names, fields)``.

        ``names`` lists the unquoted column names in ``df_schema.names`` order.
        ``fields`` maps each unquoted column name to
        ``{"name": <unquoted name>, "datatype": <Snowpark datatype>}``.
        """
        names = [self.unquote_identifier(name) for name in df_schema.names]
        fields = {}
        for field in df_schema.fields:
            col_name = self.unquote_identifier(field.name)
            fields[col_name] = {"name": col_name, "datatype": field.datatype}
        return (names, fields)

    def _get_datatype_name_from_df_datatype(self, datatype):
        """Return a short type name (e.g. 'String', 'TimestampNtz') for a Snowpark datatype.

        Early Snowpark betas expose the name directly through str(). Recent
        versions use DataType classes whose str() ends with ')' (e.g.
        'StringType()'). For those, the class name minus its 'Type' suffix is
        returned. TimestampType is further split on the timezone flag found
        inside the parentheses.
        """
        # Use str() to retrieve types used in early Snowpark betas
        datatype_name = str(datatype)
        if not datatype_name.endswith(')'):
            return datatype_name

        # Most recent versions of Snowpark use a class instead: use its name.
        # 'extra' is the parenthesised part of str(), which for timestamps
        # carries the timezone indication (e.g. "tz=ntz").
        extra = datatype_name[datatype_name.find("("):-1]
        class_name = type(datatype).__name__
        if class_name == 'TimestampType':
            if "tz=tz" in extra or "tz=ltz" in extra:
                return 'TimestampTz'
            if "tz=ntz" in extra:
                return 'TimestampNtz'
            return 'Timestamp'
        if class_name.endswith('Type'):
            return class_name[:-len('Type')]
        logging.debug("Possibly unhandled data type: %s", class_name)
        return class_name
        
        
# noinspection PyPep8Naming
class DkuSnowpark(SparkLike):
    """
    Handle to create Snowpark sessions from DSS datasets or connections
    """

    def __init__(self):
        SparkLike.__init__(self)
        self._dialect = DkuSnowparkDialect()
        self._connection_type = "Snowflake"

    def _create_session(self, connection_name, connection_info, project_key=None):
        """Open a Snowpark session on a DSS Snowflake connection.

        :param connection_name: name of the DSS connection; also stored on the
            returned session as ``dss_connection_name``
        :param connection_info: connection information handed to
            snowflake_utils.get_snowflake_connection_params()
        :param project_key: not used by this implementation
        :returns: the created snowflake.snowpark.Session, after any statements
            listed under "postConnectStatementsExpandedAndSplit" have been run
        :raises: any error from session creation or from a post-connect
            statement; in the latter case the session is closed first so it
            is not leaked
        """
        connection_parameters = snowflake_utils.get_snowflake_connection_params(connection_name, connection_info)

        logging.info("Establishing Snowpark session")
        session = Session.builder.configs(connection_parameters).create()
        logging.info("Snowpark session established")

        # Execute post connect statements if any (missing key, None or empty list: nothing to run)
        post_connect_statements = connection_parameters.get("postConnectStatementsExpandedAndSplit") or []
        try:
            for statement in post_connect_statements:
                logging.info("Executing statement: %s", statement)
                session.sql(statement).collect()
                logging.info("Statement done")
        except Exception:
            # The caller never receives the session on failure, so close it here
            try:
                session.close()
            except Exception:
                logging.warning("Unable to close Snowpark session after a failed post-connect statement", exc_info=True)
            raise

        # Add a dynamic attribute to the session to recognize its DSS connection later on
        session.dss_connection_name = connection_name
        return session

    def _split_jdbc_url(self, sf_url):
        """Split a Snowflake JDBC URL into its host and connection properties.

        ``jdbc:snowflake://<host>[:<port>]/?k1=v1&k2=v2`` becomes
        ``{'host': '<host>', 'properties': [{'name': 'k1', 'value': 'v1'}, ...]}``.
        If a query key is repeated, its last value wins. Keys with blank
        values are dropped (parse_qsl default behavior).

        :raises ValueError: if sf_url does not start with ``jdbc:snowflake://``
        """
        jdbc_prefix = "jdbc:snowflake:"
        # The '//' is required: without it urlparse() finds no network
        # location and the host would silently come back empty.
        if not sf_url.startswith(jdbc_prefix + "//"):
            raise ValueError("Invalid JDBC URL. It must start with jdbc:snowflake://")
        # Keep the leading '//' so urlparse() reads the authority as the netloc
        url_elements = urlparse(sf_url[len(jdbc_prefix):])
        query_params = dict(parse_qsl(url_elements.query))
        return {
            'host': url_elements.netloc.split(":")[0],
            'properties': [{'name': name, 'value': value} for name, value in query_params.items()],
        }

    def _check_dataframe_type(self, df):
        """Check if the dataframe is of the correct type"""
        # Any class defined under the snowflake.snowpark package is accepted
        if not df.__class__.__module__.startswith("snowflake.snowpark"):
            raise ValueError("Dataframe is not a Snowpark dataframe. Use dataset.write_dataframe() instead.")

    def _do_with_column(self, df, column_name, column_value):
        """Add or set a column in the dataframe"""
        # The name is quoted through the dialect so case and special characters are preserved
        return df.withColumn(self._dialect.quote_identifier(column_name), column_value)

    def _get_table_schema(self, schema, connection_params):
        """Return ``schema`` if it is non-blank, else the connection's default schema.

        The fallback is resolved via the base-class _get_connection_param()
        (presumably reading "defaultSchema" from the DSS params or "schema"
        from the Snowflake params — verify against SparkLike).
        """
        if schema and schema.strip():
            return schema
        return self._get_connection_param(connection_params, "defaultSchema", "schema")

    def _get_table_catalog(self, catalog, connection_params):
        """Return ``catalog`` if it is non-blank, else the connection's database.

        The fallback is resolved via the base-class _get_connection_param()
        with the "db" key.
        """
        if catalog and catalog.strip():
            return catalog
        return self._get_connection_param(connection_params, "db", "db")
