from __future__ import annotations

import os
from urllib.parse import quote_plus


def get_workload_folder_path() -> str:
    """Get the path to the workload folder.

    In local development mode (``is_local_dev()`` is true) the path is read
    from the ``db_folder_path`` entry of the local config; otherwise it is
    obtained from Dataiku's ``workload_local_folder`` API.
    """
    # Imports are kept inside the function to avoid circular imports, and are
    # split per branch so the Dataiku API is only imported when it is used.
    from backend.utils.local_dev import is_local_dev

    if is_local_dev():
        from backend.utils.local_dev_utils import load_local_config

        print("ATTENTION: Running in local development mode.")
        return load_local_config()["db_folder_path"]

    from dataiku.core import workload_local_folder

    return workload_local_folder.get_workload_local_folder_path()


DB_SCHEMA = None
TABLES_PREFIX = None


def get_db_schema():
    """Resolve the database schema used for the app's tables.

    Reads ``storage_type`` and ``db_connection`` from the webapp config. For
    ``REMOTE`` storage the schema is read from the DSS connection settings:

    * PostgreSQL: ``namingRule.schemaName`` (default ``"public"``)
    * Snowflake: ``defaultSchema`` (default ``"PUBLIC"``)

    Any other storage type or connection type yields no schema. The result is
    also stored in the module-level ``DB_SCHEMA``.

    Returns:
        The schema name, or ``None`` when no schema applies.

    Raises:
        ValueError: If ``storage_type`` is ``REMOTE`` but no ``db_connection``
            is configured.
    """
    global DB_SCHEMA
    import dataiku
    from dataiku.customwebapp import get_webapp_config

    params = get_webapp_config()
    storage_type = params.get("storage_type", "LOCAL")
    db_connection = params.get("db_connection")

    schema = None
    if storage_type == "REMOTE":
        if not db_connection:
            raise ValueError("storage_type is 'REMOTE' but no db_connection is configured.")
        # Get connection info from DSS API
        client = dataiku.api_client()
        conn_info = client.get_connection(db_connection).get_info()  # 🔒 Requires permissions to read connection details
        conn_type = conn_info["type"].lower()
        conn_params = conn_info.get_params()
        if conn_type == "postgresql":
            # namingRule may be absent or explicitly null in the connection params.
            naming_rule = conn_params.get("namingRule") or {}
            schema = naming_rule.get("schemaName", "public")
            print(f"Setting schema for PostgreSQL: {schema}")
        elif conn_type == "snowflake":
            # For Snowflake, get defaultSchema param (default to PUBLIC)
            schema = conn_params.get("defaultSchema", "PUBLIC")
            print(f"Setting schema for Snowflake: {schema}")
        # elif conn_type in ("mssql", "sqlserver"):
        #     # For MSSQL, get schema param (default to dbo)
        #     schema = conn_params.get("schema", "dbo")
        #     print(f"Setting schema for MSSQL: {schema}")
        # elif conn_type == "oracle":
        #     # For Oracle, schema is usually the username, but can be set explicitly
        #     schema = conn_params.get("schema") or conn_params.get("user")
        #     print(f"Setting schema for Oracle: {schema}")

    DB_SCHEMA = schema
    return schema


def get_tables_prefix():
    """Get the optional tables prefix from the webapp config.

    The ``tables_prefix`` setting is validated to prevent SQL injection: it
    must start with an ASCII letter and contain only ASCII letters, digits
    and underscores. A non-empty prefix is normalised to end with ``_``.
    The result is also stored in the module-level ``TABLES_PREFIX``.

    Returns:
        The normalised prefix, or ``""`` when none is configured.

    Raises:
        ValueError: If the configured prefix fails validation.
    """
    global TABLES_PREFIX
    import re

    from dataiku.customwebapp import get_webapp_config

    params = get_webapp_config()
    prefix = params.get("tables_prefix", "") or ""
    # Security: use fullmatch rather than match + "$". "$" also matches just
    # before a trailing "\n", which would let "abc\n" through the whitelist.
    if prefix and not re.fullmatch(r"[a-zA-Z][a-zA-Z0-9_]*", prefix):
        raise ValueError(
            f"Invalid tables_prefix '{prefix}': it must start with a letter and contain only "
            "alphanumeric characters and underscores"
        )

    # Ensure prefix ends with underscore if provided
    if prefix and not prefix.endswith("_"):
        prefix = f"{prefix}_"
    TABLES_PREFIX = prefix
    return prefix


def get_table_args():
    """Return table keyword arguments carrying the resolved schema.

    The result is ``{"schema": <name>}`` when ``get_db_schema()`` yields a
    truthy schema name, and an empty dict otherwise.
    """
    resolved = get_db_schema()
    return {"schema": resolved} if resolved else {}


def get_postgres_db_url(conn_params):
    """Build a ``postgresql://`` URL from DSS PostgreSQL connection parameters.

    ``user`` and ``password`` are URL-encoded with ``quote_plus`` so that
    characters such as ``@``, ``:`` or ``/`` cannot corrupt the URL. Missing,
    ``None`` or empty values fall back to ``""``, and the port to ``5432``.

    Args:
        conn_params: Mapping of connection parameters (``user``, ``password``,
            ``host``, ``port``, ``db``).

    Returns:
        The connection URL string.
    """
    # `or` guards against keys that are present but None, which would make
    # quote_plus raise TypeError or put a literal "None" in the URL.
    user = quote_plus(conn_params.get("user") or "")
    password = quote_plus(conn_params.get("password") or "")
    host = conn_params.get("host") or ""
    port = conn_params.get("port") or 5432
    dbname = conn_params.get("db") or ""
    return f"postgresql://{user}:{password}@{host}:{port}/{dbname}"


def get_snowflake_db_url(conn_params, conn_info):
    """Build a ``snowflake://`` URL from DSS Snowflake connection settings.

    The account identifier is derived from ``host`` by stripping the
    ``.snowflakecomputing.com`` suffix. The database and schema (``defaultSchema``,
    falling back to ``PUBLIC``) appear both in the URL path and in the query
    string. ``role`` is appended when set.

    Authentication is selected by ``authType``:

    * ``"OAUTH2_APP"``: an access token is fetched via
      ``conn_info.get_oauth2_credential()`` and passed with
      ``authenticator=oauth``.
    * anything else: basic user/password authentication from ``conn_params``.

    Args:
        conn_params: Mapping of connection parameters.
        conn_info: DSS connection-info object; only used for OAuth.

    Returns:
        The connection URL string.

    Raises:
        ValueError: If OAuth credentials cannot be obtained, or if basic auth
            lacks a user or password.
    """
    # `or` guards against keys that are present but None (quote_plus(None)
    # raises TypeError instead of the intended ValueError below).
    host = conn_params.get("host") or ""
    # Extract account from host (e.g., "xxx.snowflakecomputing.com" -> "xxx")
    account = host.replace(".snowflakecomputing.com", "")

    warehouse = quote_plus(conn_params.get("warehouse") or "")
    dbname = quote_plus(conn_params.get("db") or "")
    schema = quote_plus(conn_params.get("defaultSchema") or "PUBLIC")
    role = conn_params.get("role") or ""

    auth_type = conn_params.get("authType", "")

    if auth_type == "OAUTH2_APP":
        # OAuth authentication - get the access token from DSS.
        # TypeError covers a missing credential dict or a None token.
        try:
            oauth_cred = conn_info.get_oauth2_credential()
            access_token = quote_plus(oauth_cred["accessToken"])
        except (ValueError, KeyError, TypeError) as e:
            raise ValueError(
                f"Failed to get OAuth2 credentials for Snowflake connection: {e}. "
                "Ensure the connection has 'Allow details readability' enabled for your user/group."
            ) from e
        # Format: snowflake://<account>/<database>/<schema>?warehouse=...&authenticator=oauth&token=...
        url = (
            f"snowflake://{account}/{dbname}/{schema}"
            f"?warehouse={warehouse}&database={dbname}&schema={schema}"
            f"&authenticator=oauth&token={access_token}"
        )
    else:
        # Basic auth (user/password)
        user = quote_plus(conn_params.get("user") or "")
        password = quote_plus(conn_params.get("password") or "")
        if not user or not password:
            raise ValueError("Snowflake connection requires user and password for basic authentication.")
        # Format: snowflake://user:password@account/database/schema?warehouse=...
        url = (
            f"snowflake://{user}:{password}@{account}/{dbname}/{schema}"
            f"?warehouse={warehouse}&database={dbname}&schema={schema}"
        )

    # Both URL forms already carry a query string, so "&" is always correct.
    if role:
        url += f"&role={quote_plus(role)}"

    return url


def get_mysql_db_url(conn_params):
    """Build a ``mysql+pymysql://`` URL from DSS MySQL connection parameters.

    ``user`` and ``password`` are URL-encoded with ``quote_plus`` so that
    characters such as ``@``, ``:`` or ``/`` cannot corrupt the URL. Missing,
    ``None`` or empty values fall back to ``""``, and the port to ``3306``.

    Args:
        conn_params: Mapping of connection parameters (``user``, ``password``,
            ``host``, ``port``, ``db``).

    Returns:
        The connection URL string.
    """
    # `or` guards against keys that are present but None, which would make
    # quote_plus raise TypeError or put a literal "None" in the URL.
    user = quote_plus(conn_params.get("user") or "")
    password = quote_plus(conn_params.get("password") or "")
    host = conn_params.get("host") or ""
    port = conn_params.get("port") or 3306
    dbname = conn_params.get("db") or ""
    # SQLAlchemy MySQL URL: mysql+pymysql://user:password@host:port/dbname
    return f"mysql+pymysql://{user}:{password}@{host}:{port}/{dbname}"


def get_db_url():
    """Get the database URL for the configured storage.

    Reads ``storage_type`` and ``db_connection`` from the webapp config:

    * ``LOCAL`` (the default): a SQLite file ``data_store.db`` inside the
      folder returned by ``get_workload_folder_path()``.
    * any other value: the DSS connection named by ``db_connection``. Only
      PostgreSQL connections are currently supported.

    Returns:
        The database URL string.

    Raises:
        ValueError: If non-local storage is selected without a
            ``db_connection``, or the connection type is not supported.
    """
    import dataiku
    from dataiku.customwebapp import get_webapp_config

    params = get_webapp_config()
    storage_type = params.get("storage_type", "LOCAL")
    db_connection = params.get("db_connection")

    if storage_type == "LOCAL":
        local_db_path = os.path.join(get_workload_folder_path(), "data_store.db")
        return f"sqlite:///{local_db_path}"

    if not db_connection:
        raise ValueError(f"storage_type is '{storage_type}' but no db_connection is configured.")

    # Get connection info from DSS API
    client = dataiku.api_client()
    conn_info = client.get_connection(db_connection).get_info()  # 🔒 Requires permissions to read connection details
    conn_type = conn_info["type"].lower()
    conn_params = conn_info.get_params()

    # Build DB URL based on connection type. The per-type helpers URL-encode
    # the credentials so special characters do not break the URL.
    if conn_type == "postgresql":
        return get_postgres_db_url(conn_params)
    # elif conn_type == "snowflake":
    #     return get_snowflake_db_url(conn_params, conn_info)
    # elif conn_type == "mysql":
    #     return get_mysql_db_url(conn_params)
    # elif conn_type in ("mssql", "sqlserver"):
    #     # MSSQL connection params
    #     user = quote_plus(conn_params.get("user", ""))
    #     password = quote_plus(conn_params.get("password", ""))
    #     host = conn_params.get("host", "")
    #     port = conn_params.get("port", 1433)
    #     dbname = conn_params.get("db", "")
    #     # SQLAlchemy MSSQL URL: mssql+pyodbc://user:password@host:port/dbname?driver=ODBC+Driver+17+for+SQL+Server
    #     driver = quote_plus(conn_params.get("driver", "ODBC Driver 17 for SQL Server"))
    #     return f"mssql+pyodbc://{user}:{password}@{host}:{port}/{dbname}?driver={driver}"
    # elif conn_type == "oracle":
    #     # Oracle connection params
    #     user = quote_plus(conn_params.get("user", ""))
    #     password = quote_plus(conn_params.get("password", ""))
    #     host = conn_params.get("host", "")
    #     port = conn_params.get("port", 1521)
    #     # Prefer service name, fallback to SID
    #     service = conn_params.get("db")
    #     sid = conn_params.get("sid")
    #     if service:
    #         # SQLAlchemy Oracle URL: oracle+cx_oracle://user:password@host:port/?service_name=service
    #         return f"oracle+cx_oracle://{user}:{password}@{host}:{port}/?service_name={service}"
    #     elif sid:
    #         # SQLAlchemy Oracle URL: oracle+cx_oracle://user:password@host:port/sid
    #         return f"oracle+cx_oracle://{user}:{password}@{host}:{port}/{sid}"
    #     raise ValueError("Oracle connection requires either a service name or SID.")
    else:
        raise ValueError(f"Unsupported connection type: {conn_type}, ONLY postgresql is supported for now.")
