import logging
import json
import re
import datetime as dt
import pandas as pd
import functools

import dataiku
from dataiku.base.spark_like import SparkLike
from dataiku.base.sql_dialect import SparkLikeDialect
from urllib.parse import urlparse

try:
    from bigframes import Session
    from bigframes.session.clients import ClientsProvider
    from bigframes._config.bigquery_options import BigQueryOptions
    from google.api_core.client_options import ClientOptions
    import bigframes.pandas as bpd
    from google.cloud import bigquery
except ImportError as e:
    raise Exception("Unable to import Bigframes libraries. Make sure you are using a code-env where Bigframes is installed. Cause: " + str(e))

try:
    import google.oauth2 as g_oauth2
    import google.auth as g_auth
except ImportError as e:
    raise Exception("Unable to import Google client libraries. Make sure you are using a code-env where google-api-python-client, google-auth-httplib2, google-auth-oauthlib are installed. Cause: " + str(e))

class DkuBigframesDialect(SparkLikeDialect):

    def __init__(self):
        """Initialize the dialect by running the SparkLikeDialect initializer."""
        super().__init__()

    def _get_to_dss_types_map(self):
        if self._to_dss_types_map is None:
            self._to_dss_types_map = {
                            'Int64': 'bigint',
                            'Float64': 'double',
                            'boolean': 'boolean',
                            'decimal128': 'double',
                            'decimal256': 'double',
                            'binary': 'string',
                            'geometry': 'string', # we don't (yet?) handle geometry columns in BQ
                            'time64': 'string',
                            'Interval': 'string', # shouldn't be seen
                            'Json': 'string', # shouldn't be seen, the lib claims it's converting it to string
                            'Range': 'string' # what's this guy
                        }
        return self._to_dss_types_map
        
    def allow_empty_schema_after_catalog(self):
        """Whether specifying a table as (catalog, table) is possible"""
        return True
        
    def identifier_quote_char(self):
        """Get the character used to quote identifiers"""
        return '`'
    
    def _column_name_to_sql_column(self, identifier):
        return identifier
    
    def _python_literal_to_sql_literal(self, value, column_type, original_type=None):
        # for temporal types, value should always arrive as str here
        try:
            if column_type == 'date':
                value = str(value)
                if value.endswith('Z'):
                    return dt.datetime.strptime(str(value), "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=dt.timezone.utc)
                else:
                    return dt.datetime.strptime(str(value), '%Y-%m-%d %H:%M:%S.%f').replace(tzinfo=dt.timezone.utc)
            elif column_type == 'dateonly':
                return  dt.datetime.strptime(str(value), '%Y-%m-%d').date()
            elif column_type == 'datetimenotz':
                return dt.datetime.strptime(str(value), '%Y-%m-%d %H:%M:%S.%f')
            else:
                return value
        except ValueError as e:
            logging.warning("Unable to parse '%s' as a temporal value, using as string (cause: %s)" % (value, str(e)))
            return value

    def _get_components_from_df_schema(self, dtypes):
        """
        Break the ``dtypes`` of a dataframe down into column names and column descriptors.

        :param dtypes: the dtypes of the dataframe (entries of its index are the column names,
                       its values are the datatypes)
        :return: a tuple (names, fields) where names is the list of unquoted column names, and
                 fields maps each column name as found in ``dtypes`` to a dict holding the
                 unquoted "name" and the "datatype"
        """
        # original author's note: when there's native partitioning on the table, bigframes
        # puts the column in the dataframe index
        unquote = self.unquote_identifier
        column_names = [unquote(raw_name) for raw_name in dtypes.index]
        column_fields = {
            raw_name: {"name": unquote(raw_name), "datatype": datatype}
            for raw_name, datatype in zip(dtypes.index, dtypes.values)
        }
        return (column_names, column_fields)

    def _get_dss_type_from_df_datatype(self, datatype):
        # go for the arrow type when possible (simple types don't have it)
        if hasattr(datatype, 'pyarrow_dtype'):
            return self._get_dss_type_from_arrow_datatype(datatype.pyarrow_dtype)
        # else it's a 'pandas' type
        datatype_name = str(datatype)
        m = re.match("^([^\\[\\]()]+)(\\[([^\\[\\]()]+)\\].*)?(\\(([^()]+)\\).*)?", datatype_name)
        if m is None:
            # unparseable dtype, very unexpected...
            logging.debug("Possibly unhandled data type: %s" % datatype_name)
            return 'string'
        base_type = m.group(1)
        if m.group(3) is None and m.group(5) is None:
            # simple type -> use map
            return self._get_to_dss_types_map().get(base_type, 'string')
        else:
            # more complex shit, work more
            annotation = m.group(3) or m.group(5)
            if base_type.startswith('decimal'): # there are at least 2 flavors, decimal128 and decimal256
                return 'double'
            elif base_type == 'date32':
                return 'dateonly'
            elif base_type == 'timestamp':
                if 'tz=' in annotation:
                    return 'date'
                else:
                    return 'datetimenotz'
            elif base_type.startswith("array"):
                return 'array' # can't do better
            elif base_type.startswith("struct"):
                return 'object' # can't do better
            else:
                return 'string'

    def _get_dss_type_from_arrow_datatype(self, datatype):
        # use str conversion and comparision to avoid importing the pyarrow classes
        datatype_name = str(datatype)
        if datatype_name.startswith('string'):
            return {'type':'string', 'originalType':'STRING'}
        elif datatype_name.startswith('bool'):
            return {'type':'boolean', 'originalType':'BOOLEAN'}
        elif datatype_name.startswith('int64'):
            return {'type':'bigint', 'originalType':'INT64'}
        elif datatype_name.startswith('double'):
            return {'type':'double', 'originalType':'FLOAT64'}
        elif datatype_name.startswith('decimal128'):
            return {'type':'double', 'originalType':'DECIMAL'}
        elif datatype_name.startswith('decimal256'):
            return {'type':'double', 'originalType':'BIGDECIMAL'}
        elif datatype_name == 'binary':
            return {'type':'string', 'originalType':'BINARY'}
        elif datatype_name.startswith('date32'):
            return {'type':'dateonly', 'originalType':'DATE'}
        elif datatype_name.startswith('timestamp'):
            if 'tz=' in datatype_name:
                return {'type':'date', 'originalType':'TIMESTAMP'}
            else:
                return {'type':'datetimenotz', 'originalType':'DATETIME'}
        elif datatype_name.startswith('time64'):
            return {'type':'string', 'originalType':'TIME'}
        elif datatype_name.startswith('list<'):
            array_content = self._get_dss_type_from_arrow_datatype(datatype.value_type)
            return {'type':'array', 'originalType':'ARRAY', 'arrayContent':array_content}
        elif datatype_name.startswith('struct<'):
            object_fields = []
            for i in range(0, datatype.num_fields):
                field = datatype.field(i)
                sub = self._get_dss_type_from_arrow_datatype(field.type)
                sub["name"] = field.name
                object_fields.append(sub)
            return {'type':'object', 'originalType':'STRUCT', 'objectFields':object_fields}
        else:
            return {'type':'string'}


# noinspection PyPep8Naming
class DkuBigframes(SparkLike):
    """
    Handle to create Bigframes sessions from DSS datasets or connections
    """

    def __init__(self, session_ordering_mode='partial'):
        """
        Build a handle that creates Bigframes sessions from DSS datasets or connections.

        :param session_ordering_mode: ordering mode to use when creating a session. Tables having
                                      require_partition_filter set to true can only be used with the
                                      'partial' ordering mode; for regular tables, 'strict' is
                                      usable too. See
                                      https://cloud.google.com/bigquery/docs/use-bigquery-dataframes#partial-ordering-mode
        """
        super().__init__()
        self.session_ordering_mode = session_ordering_mode
        self._connection_type = "BigQuery"
        self._dialect = DkuBigframesDialect()

    # BigFrames doesn't follow the other API on the table() and sql() method names
    def _unquote_full_identifier(self, full_table_name):
        """
        Remove the quoting of each part of a dot-separated table identifier.

        :param full_table_name: the (possibly quoted) full name of the table
        :return: the parts, each unquoted by the dialect, joined back with '.'
        """
        # read_gbq_table() wants unquoted table names, and splits on '.' :(
        unquote = self._dialect.unquote_identifier
        return '.'.join(unquote(part) for part in full_table_name.split("."))
    def _get_dataframe_from_table(self, session, full_table_name, dataset_params, partitions_filters):
        """
        Read a table as a dataframe, via the session's read_gbq_table().

        :param session: the Bigframes session
        :param full_table_name: full name of the table (unquoted before being handed to the session)
        :param dataset_params: params of the dataset; a non-blank "externalPartitionField" is used
                               as the index column of the dataframe
        :param partitions_filters: filters handed over as is to read_gbq_table()
        :return: the dataframe
        """
        partition_field = dataset_params.get("externalPartitionField")
        stripped_field = partition_field.strip() if partition_field is not None else ""
        index_col = [stripped_field] if stripped_field else []
        table_id = self._unquote_full_identifier(full_table_name)
        return session.read_gbq_table(table_id, use_cache=False, index_col=index_col, filters=partitions_filters)
    def _get_dataframe_from_sql(self, session, query):
        """
        Run a query as a dataframe, via the session's read_gbq_query(), with caching disabled.

        :param session: the Bigframes session
        :param query: the SQL query
        :return: the dataframe
        """
        query_df = session.read_gbq_query(query, use_cache=False)
        return query_df
    def _write_dataframe_to_table(self, df, full_table_name):
        """
        Append the rows of a dataframe to a table.

        :param df: the dataframe to write
        :param full_table_name: full name of the destination table (unquoted before the write)
        """
        # When given a destination table, bigframes uses "fail" as default for if_exists. The
        # table is created beforehand by the backend if necessary, hence the "append".
        # Note that before bigframes 2.5.0, the "fail" wasn't actually failing anything
        destination = self._unquote_full_identifier(full_table_name)
        df.to_gbq(destination, if_exists="append")

    # BigFrames also doesn't follow the other API on exposing a 'schema' field, and sticks to Pandas terms
    def _get_schema_from_dataframe(self, df):
        return df.dtypes

    # Bigframes sticks to a very pandas-like API, so most filtering functions are different
    def _get_union_all(self, dfs):
        """
        Stack several dataframes into one, using the concat() of bigframes.pandas.

        :param dfs: the dataframes to concatenate
        :return: the concatenated dataframe
        """
        unioned = bpd.concat(dfs)
        return unioned

    def _build_eq_partition_filter(self, col_name, value):
        return (col_name, '==', value)
    def _build_ge_and_lt_partition_filter(self, col_name, value_low, value_hig):
        return [(col_name, '>=', value_low), (col_name, '<', value_hig)]
    def _combine_dimension_filters(self, filters):
        return functools.reduce(lambda f1, f2: f1.concat(f2), filters) # dimension filters are all AND
    def _combine_partition_filters(self, filters):
        return filters # bigframes will treat them as OR

    def _create_session(self, connection_name, connection_info, project_key=None):
        """
        Open a Bigframes session on the BigQuery project of a DSS connection.

        The location of the session is the one of the default dataset of the connection when
        there is one, else the value of the 'Location' property of the connection. The
        post-connect statements of the connection, if any, are run with the BigQuery client
        of the session before the session is returned.

        :param connection_name: name of the DSS connection; also stored on the session as
                                ``dss_connection_name``
        :param connection_info: info of the DSS connection, giving access to its resolved and raw params
        :param project_key: project key handed over when building the credentials
        :return: the Bigframes session
        """
        # lessen spam in the output
        bpd.options.display.progress_bar = None

        credentials = self._get_credentials(connection_name, project_key, connection_info)
        connection_params = connection_info.get_resolved_params()
        connection_raw_params = connection_info.get("params", {})

        # Setup endpoints in case of parallel universe
        universe_domain = connection_raw_params.get("universeDomain")
        if universe_domain:
            custom_endpoint = connection_raw_params.get("customEndpoint")
            client_endpoints_override = {
                "bqclient": custom_endpoint or f"https://bigquery.{universe_domain}",
                "bqconnectionclient": f"bigqueryconnection.{universe_domain}",
                "bqstoragereadclient": f"bigquerystorage.{universe_domain}",
                "bqstoragewriteclient": f"bigquerystorage.{universe_domain}",
            }
        else:
            client_endpoints_override = {}

        # need to find the location of the dataset, because otherwise jobs (=queries) refuse to run
        project_id = self._get_table_catalog(None, connection_raw_params)
        dataset_id = self._get_table_schema(None, connection_raw_params)
        if dataset_id is None or len(dataset_id) == 0:
            # we're in a pinch, there's no way to get the location from the dataset... Try to get
            # it from the connection properties. Note that if you have a "Location" property defined,
            # then the Simba driver effectively locks you on that connection :(
            logging.warning("No default dataset defined on the BigQuery connection, will use location from the 'Location' property (if any). Add a 'DefaultDataset' to set a default dataset.")
            location = self._get_connection_param(connection_raw_params, "location", "Location") # there's no "location" field, this is just to make the method happy
        else:
            # try getting the dataset (should be permitted, otherwise you're going to have
            # trouble querying it...)
            client_options = ClientOptions(api_endpoint=client_endpoints_override["bqclient"]) if "bqclient" in client_endpoints_override else None
            client = bigquery.Client(project=project_id, credentials=credentials, client_options=client_options)
            dataset = client.get_dataset(dataset_id)
            location = dataset.location

        # check if we have a kms key name in that connection (if defined several times, the last one wins)
        kms_key_name = None
        for prop in connection_raw_params.get("properties", []):
            if prop.get("name") == 'KMSKeyName':
                kms_key_name = prop.get("value")

        bq_client_provider = ClientsProvider(project=project_id, credentials=credentials, location=location, client_endpoints_override=client_endpoints_override)
        bq_options = BigQueryOptions(location=location, kms_key_name=kms_key_name, ordering_mode=self.session_ordering_mode) # leaving ordering_mode to strict makes it impossible to open tables with require_partition_filter

        logging.info("Establishing BigQuery session")
        session = Session(context=bq_options, clients_provider=bq_client_provider)
        logging.info("BigQuery session established")

        # Execute post connect statements if any
        if "postConnectStatementsExpandedAndSplit" in connection_params and len(connection_params["postConnectStatementsExpandedAndSplit"]) > 0:
            for statement in connection_params["postConnectStatementsExpandedAndSplit"]:
                logging.info("Executing statement: %s" % statement)
                # query() returns a BigQuery job, and result() is what waits for its completion
                # (there is no collect() on these jobs)
                session.bqclient.query(statement).result()
                logging.info("Statement done")

        session.dss_connection_name = connection_name  # Add a dynamic attribute to the session to recognize its DSS connection later on
        return session

    def _split_jdbc_url(self, bq_url):
        if not bq_url.startswith("jdbc:bigquery:"):
            raise ValueError("Invalid JDBC URL. It must start with jdbc:bigquery://")
        bq_url = bq_url[len("jdbc:bigquery://"):]
        url_elements = urlparse(bq_url)
        params = {}
        params['host'] = url_elements.netloc.split(":")[0]
        params['properties'] = []
        for kv in url_elements.params.split(";"):
            i = kv.find("=")
            k = kv[:i]
            v = kv[i+1:]
            params['properties'].append({'name':k, 'value':v})
        return params
        

    def _check_dataframe_type(self, df):
        """Check if the dataframe is of the correct type"""
        if not df.__class__.__module__.startswith("bigframes.dataframe"):
            raise ValueError("Dataframe is not a Bigframes dataframe. Use dataset.write_dataframe() instead.")

    def _prepare_dataframe_for_write(self, df):
        # maybe a column is cast aside in the index. If the index has no name, it's because the ordering_mode is
        # not "partial" and BQ has created a special full order
        if hasattr(df, '_has_index') and df._has_index and df.index.name is not None and len(df.index.name) > 0:
            return df.reset_index()
        else:
            return df

    def _do_with_column(self, df, column_name, column_value):
        """Add or set a column in the dataframe"""
        df = df.copy() # otherwise original is modified... There is no 'withColumn()' in bigframes :(
        df[column_name] = column_value
        return df

    def _get_table_schema(self, schema, connection_params):
        if schema and schema.strip():
            return schema
        # note: there is no 'schema' or 'dataset' field, but there can be a DefaultDataset property
        return self._get_connection_param(connection_params, "schema", "DefaultDataset")

    def _get_table_catalog(self, catalog, connection_params):
        if catalog and catalog.strip():
            return catalog
        # note: 'ProjectId' is the name of the Simba driver property
        return self._get_connection_param(connection_params, "projectId", "ProjectId")


    def _get_credentials(self, connection_name, project_key, connection_info):
        """
        Build the Google credentials corresponding to the authentication type of a DSS connection.

        Handled authentication types are "KEYPAIR" (service account key), "ENVIRONMENT"
        (default credentials of the environment) and "OAUTH" (access token resolved by DSS,
        renewed through the DSS API when a refresh is needed).

        :param connection_name: name of the DSS connection
        :param project_key: project key used when fetching the connection info again to refresh an OAuth token
        :param connection_info: info of the DSS connection
        :return: the credentials object
        :raises ValueError: if the authentication type isn't handled, or if the keypair or the
                            access token it requires is missing from the connection
        """

        connection_params = connection_info.get_resolved_params()

        if connection_params['authType'] == "KEYPAIR":
            if 'appSecretContent' in connection_params:
                keyRaw = connection_params['appSecretContent']
            elif 'keyPath' in connection_params:
                keyRaw = connection_params['keyPath']
            else:
                raise ValueError("No keypair found in %s connection. Please refer to DSS Service Account Auth documentation." % connection_name)
            try:
                key = json.loads(keyRaw)
            except ValueError as e:
                # not json? check if it's a path you can read
                # note that we shouldn't arrive here, because this method gets "resolvedParams", and those have the private key as json, not path
                logging.warning("Keypair not resolved as json, trying to use as path. Error: %s" % str(e))
                bq_credentials = g_oauth2.service_account.Credentials.from_service_account_file(keyRaw)
            else:
                # outside of the try, so that the error of a malformed key isn't mistaken for a parsing failure
                bq_credentials = g_oauth2.service_account.Credentials.from_service_account_info(key)

            # As of bigframes 2.5.0, we need to manually specify the scopes
            # This line is doing what was being done prior to 2.5.0
            bq_credentials = bq_credentials.with_scopes(["https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/bigquery"],
                                                        default_scopes=bq_credentials._default_scopes)
        elif connection_params['authType'] == "ENVIRONMENT":
            bq_credentials, _ = g_auth.default()
        elif connection_params['authType'] == "OAUTH":
            if 'accessToken' not in connection_info['resolvedOAuth2Credential']:
                raise ValueError("No accessToken found in %s connection. Please refer to DSS OAuth2 credentials documentation." % connection_name)
            accessToken = connection_info['resolvedOAuth2Credential']['accessToken']
            def _oauth2_refresh_handler(req, scopes):
                connection = dataiku.api_client().get_connection(connection_name)
                info = connection.get_info(project_key) # this'll grab a new access token if needed
                resolved_creds = info["resolvedOAuth2Credential"]
                # "expiry" is handled as milliseconds since epoch. google-auth compares the expiry
                # to a naive UTC 'now', so hand it a naive UTC datetime and not one in local time
                expiry = dt.datetime.fromtimestamp(resolved_creds["expiry"] / 1000, tz=dt.timezone.utc).replace(tzinfo=None)
                return resolved_creds["accessToken"], expiry

            bq_credentials = g_oauth2.credentials.Credentials(accessToken, refresh_handler=_oauth2_refresh_handler)

        else:
            raise ValueError("Unsupported authentication type '%s'." % connection_params['authType'])

        return bq_credentials

    def _cast_to_target_types(self, df, dss_schema, qualified_table_id):
        """
        Cast the columns of a dataframe to the types of the table it is about to be inserted into.

        The types of the columns of the target table are read from its INFORMATION_SCHEMA.COLUMNS,
        and casts are added for columns whose target is a STRING, or for temporal columns whose
        target is another temporal type. This is best effort: if anything fails, a warning is
        logged and the dataframe is returned with whatever casts were done until then.

        :param df: the dataframe to insert; the casts are assigned on this very object
        :param dss_schema: not used in this method
        :param qualified_table_id: dot-separated full identifier of the target table
        :return: the dataframe
        """
        column_names, column_fields = self._dialect._get_components_from_df_schema(df.dtypes)
        # check the actual schema we're inserting into, and add casts as needed.
        # this is based on the assumption that DSS manages the schema, and may
        # have done some type erasure. And BigQuery is known to almost never do
        # implicit cast()
        try:
            # get the table name alone from the qualified table id
            parts = qualified_table_id.split(".")
            table_name = self._dialect.unquote_identifier(parts[-1])
            # the INFORMATION_SCHEMA.COLUMNS view, qualified like the table itself
            columns_table = '.'.join(parts[:-1] + ["`INFORMATION_SCHEMA`", "`COLUMNS`"])
            # catalog + schema for table_name will be defined by what's in columns_table, no need to filter table_schema and table_catalog
            tdf = df._session.read_gbq_query("""select column_name, data_type from %s where table_name = '%s' order by ordinal_position""" % (columns_table, table_name))
            # column name -> type name in the target table
            target_column_types = {}
            for i, r in tdf.iterrows():
                target_column_types[r["column_name"]] = r["data_type"]
            for column_name in column_names:
                field = column_fields[column_name]
                target_datatype_name = target_column_types.get(column_name)
                if target_datatype_name is None:
                    logging.warning("Unable to find target datatype for %s" % column_name)
                    continue # not a good sign, we're inserting but the output column isn't there...
                df_datatype = self._dialect._get_dss_type_from_df_datatype(field["datatype"])

                # the dialect gives either a dict (arrow-backed types, compared on their 'originalType')
                # or a plain DSS type name
                datatype_name = df_datatype.get('originalType', '').upper() if isinstance(df_datatype, dict) else str(df_datatype).upper()
                target_datatype_name = target_datatype_name.upper()

                # each of the cases below is for a different target type, so at most one applies
                if datatype_name != 'STRING' and target_datatype_name == 'STRING':
                    df[column_name] = df[column_name].astype('string') # that cast should work regardless of original type
                if (datatype_name == 'DATE' or datatype_name == 'TIMESTAMP') and target_datatype_name == 'DATETIME':
                    df[column_name] = df[column_name].astype('timestamp[us][pyarrow]')
                if (datatype_name == 'DATE' or datatype_name == 'DATETIME') and target_datatype_name == 'TIMESTAMP':
                    df[column_name] = df[column_name].astype('timestamp[us, tz=UTC][pyarrow]')
                if (datatype_name == 'DATETIME' or datatype_name == 'TIMESTAMP') and target_datatype_name == 'DATE':
                    df[column_name] = df[column_name].astype('date32[day][pyarrow]')
        except Exception as e:
            # deliberately not fatal: the insert is attempted with the types as they are
            logging.warning("Unable to check output schema, inserting as is : %s" % str(e))
        return df
