"""Alembic migration utilities.

This module provides helper functions for database migrations that support:
- Table name prefixes (via TABLES_PREFIX environment variable)
- Schema-qualified names for PostgreSQL/Snowflake (via DB_SCHEMA environment variable)
"""

import json
import logging
import os
import re
import time
from contextlib import contextmanager
from typing import Callable, Dict

import sqlalchemy as sa
from sqlalchemy.types import Text, TypeDecorator

from alembic import op

# Configure migration logger at INFO level.
# The logger is forced to INFO so that timed_step() progress lines are emitted.
# A "[Migration]"-prefixed StreamHandler is attached only when this logger has
# no handler attached directly to it yet, so an existing direct handler is not
# duplicated.
# NOTE(review): propagation to ancestor loggers ("alembic", root) is left
# enabled; if env.py configures handlers there (e.g. via fileConfig), each
# message may be printed twice — confirm against the project's logging setup.
migration_logger = logging.getLogger("alembic.runtime.migration")
migration_logger.setLevel(logging.INFO)
if not migration_logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("[Migration] %(message)s"))
    migration_logger.addHandler(handler)


@contextmanager
def timed_step(step_name: str, logger_instance=None):
    """Time the enclosed block and log when it starts, finishes, or fails.

    Usage::

        with timed_step("Step 1/4: Add columns"):
            ...  # do work

    Emits, on success::

        [Step 1/4: Add columns] Starting...
        [Step 1/4: Add columns] Completed in 1.23s

    If the block raises an ``Exception``, an error line
    ``[<step>] Failed after <secs>s: <error>`` is logged and the exception is
    re-raised unchanged.

    :param step_name: Label embedded in every log line.
    :param logger_instance: Logger to write to; when falsy the module-level
        ``migration_logger`` is used instead.
    """
    logger = logger_instance if logger_instance else migration_logger
    logger.info("[%s] Starting...", step_name)
    started_at = time.perf_counter()

    def elapsed() -> float:
        # Seconds since the step started, measured with a monotonic clock.
        return time.perf_counter() - started_at

    try:
        yield
        logger.info("[%s] Completed in %.2fs", step_name, elapsed())
    except Exception as exc:
        logger.error("[%s] Failed after %.2fs: %s", step_name, elapsed(), exc)
        raise


class JsonEncoded(TypeDecorator):
    """Column type storing JSON-serialisable Python values in a ``Text`` column.

    Values are encoded with ``json.dumps`` on the way in and decoded with
    ``json.loads`` on the way out. ``None`` is passed through untouched in both
    directions instead of being encoded as the JSON literal ``null``.
    """

    impl = Text
    # The type holds no per-instance configuration, so it is declared safe for
    # SQLAlchemy's statement cache.
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Serialise *value* to a JSON string; ``None`` stays ``None``."""
        return None if value is None else json.dumps(value)

    def process_result_value(self, value, dialect):
        """Parse a stored JSON string back into Python; ``None`` stays ``None``."""
        return None if value is None else json.loads(value)


def is_snowflake() -> bool:
    """Return True when the connection bound to the running migration is Snowflake."""
    dialect_name = op.get_bind().dialect.name
    return dialect_name == "snowflake"

def is_sqlite() -> bool:
    """Return True when the connection bound to the running migration is SQLite."""
    dialect_name = op.get_bind().dialect.name
    return dialect_name == "sqlite"

def get_tables_prefix() -> str:
    """Return the table-name prefix configured via the ``TABLES_PREFIX`` env var.

    The prefix is interpolated unquoted into table names that end up in raw SQL
    (see ``get_qualified_table_name``), so it is validated to prevent SQL
    injection: it must start with an ASCII letter and contain only ASCII
    letters, digits and underscores. An unset or empty variable yields ``""``.

    :raises ValueError: if the variable is set to a value violating those rules.
    """
    prefix = os.environ.get("TABLES_PREFIX") or ""
    # Security: use fullmatch rather than match + "$". In Python's re, "$" also
    # matches just before a trailing newline, so "app_\n" used to pass.
    if prefix and not re.fullmatch(r"[a-zA-Z][a-zA-Z0-9_]*", prefix):
        raise ValueError(
            f"Invalid TABLES_PREFIX '{prefix}': it must start with a letter and contain only alphanumeric characters and underscores"
        )
    return prefix


def get_db_schema() -> str | None:
    """Get database schema from environment."""
    return os.environ.get("DB_SCHEMA", None)


def get_prefixed_table_name(table_name: str) -> str:
    """Return *table_name* with the configured ``TABLES_PREFIX`` prepended.

    When no prefix is configured the name is returned unchanged (as a string).
    """
    return f"{get_tables_prefix()}{table_name}"


def table_exists(table_name: str) -> bool:
    """Report whether the prefixed table exists, looking inside ``DB_SCHEMA`` if set.

    :param table_name: Name of the table (without prefix).
    """
    bind = op.get_bind()
    return bind.dialect.has_table(
        bind,
        get_prefixed_table_name(table_name),
        schema=get_db_schema(),
    )


def get_qualified_table_name(table_name: str) -> str:
    """Return ``schema.prefixed_name`` when ``DB_SCHEMA`` is set, else the prefixed name.

    Intended for raw SQL on PostgreSQL/Snowflake, where the schema must be
    spelled out explicitly.
    """
    name = get_prefixed_table_name(table_name)
    schema = get_db_schema()
    return f"{schema}.{name}" if schema else name


def get_fk_prefix() -> str:
    """Return the string to prepend to a bare table name in a foreign-key target.

    This is ``"<schema>.<tables_prefix>"`` when ``DB_SCHEMA`` is set, otherwise
    just the tables prefix (possibly ``""``).
    """
    tables_prefix = get_tables_prefix()
    schema = get_db_schema()
    return f"{schema}.{tables_prefix}" if schema else tables_prefix


def create_table(table_name: str, *columns, **kwargs):
    """Proxy for ``op.create_table`` that applies ``TABLES_PREFIX`` to the table name.

    Columns and keyword arguments are forwarded unchanged; the result of
    ``op.create_table`` is returned.
    """
    return op.create_table(get_prefixed_table_name(table_name), *columns, **kwargs)


def drop_table(table_name: str):
    """Proxy for ``op.drop_table`` that applies ``TABLES_PREFIX`` to the table name."""
    return op.drop_table(get_prefixed_table_name(table_name))


def batch_alter_table(table_name: str, **kwargs):
    """Proxy for ``op.batch_alter_table`` using the prefixed table name.

    Keyword arguments are forwarded unchanged, and the object returned by
    ``op.batch_alter_table`` is handed back to the caller.
    """
    return op.batch_alter_table(get_prefixed_table_name(table_name), **kwargs)


def drop_index(index_name: str, table_name: str, **kwargs):
    """Best-effort ``op.drop_index`` on the prefixed table.

    Any exception raised by ``op.drop_index`` is logged as a warning on
    ``migration_logger`` and swallowed so the migration can continue; ``None``
    is returned in that case. An invalid ``TABLES_PREFIX`` is not swallowed:
    the ``ValueError`` from prefix validation propagates, since it signals
    misconfiguration rather than a missing index.

    NOTE(review): unlike ``create_index`` in this module, *index_name* is not
    prefixed here, so with ``TABLES_PREFIX`` set an index created through
    ``create_index`` lives under a different name. Confirm which convention
    callers rely on before changing it.
    NOTE(review): on backends with transactional DDL (e.g. PostgreSQL) a failed
    DROP INDEX may still leave the surrounding transaction aborted even though
    the exception is swallowed here — confirm.

    :param index_name: Index name, passed to ``op.drop_index`` as-is.
    :param table_name: Name of the table (without prefix).
    :param kwargs: Forwarded to ``op.drop_index``.
    """
    # Resolve the table name before the try block. Previously a prefix
    # validation error raised inside it left `prefixed_table` unbound, and the
    # except clause crashed with UnboundLocalError, masking the real error.
    prefixed_table = get_prefixed_table_name(table_name)
    try:
        return op.drop_index(index_name, table_name=prefixed_table, **kwargs)
    except Exception as e:
        # Log the warning through the module's migration logger and continue.
        migration_logger.warning(
            "Warning: Could not drop index %s on table %s: %s", index_name, prefixed_table, e
        )
        return None



def create_index(index_name: str, table_name: str, columns, **kwargs):
    """Proxy for ``op.create_index`` that prefixes both the index and the table name.

    When no ``TABLES_PREFIX`` is configured, *index_name* is passed through
    untouched (the same object), so special name types such as ``op.f(...)``
    results are preserved.
    """
    tables_prefix = get_tables_prefix()
    if tables_prefix:
        index_name = f"{tables_prefix}{index_name}"
    return op.create_index(index_name, get_prefixed_table_name(table_name), columns, **kwargs)


def add_column(table_name: str, column):
    """Proxy for ``op.add_column`` targeting the prefixed table."""
    return op.add_column(get_prefixed_table_name(table_name), column)


def drop_column(table_name: str, column_name: str):
    """Proxy for ``op.drop_column`` targeting the prefixed table."""
    return op.drop_column(get_prefixed_table_name(table_name), column_name)


def rename_table(old_table_name: str, new_table_name: str):
    """Proxy for ``op.rename_table`` with ``TABLES_PREFIX`` applied to both names."""
    source, target = (get_prefixed_table_name(name) for name in (old_table_name, new_table_name))
    return op.rename_table(source, target)


def get_data_to_migrate(
    table_name: str, column_name: str, id_col: str, connection: sa.Connection, transformation: Callable[[str], str]
) -> Dict[int, str]:
    """Read one column of every row and return its transformed values keyed by row id.

    :param table_name: Name of the table (without prefix)
    :param column_name: Name of the column to migrate
    :param id_col: Name of the ID column
    :param connection: Database connection
    :param transformation: Function applied to each stored value to produce its new format
    :return: Dict {row id -> transformed column value}
    """
    table = get_qualified_table_name(table_name)
    query = f"SELECT {table}.{id_col}, {table}.{column_name} FROM {table}"
    migrated: Dict[int, str] = {}
    # Each result row has exactly two columns: (id, value).
    for row_id, stored_value in connection.exec_driver_sql(query):
        migrated[row_id] = transformation(stored_value)
    return migrated


def update_table_with_migrated_data(
    table_name: str, column_name: str, id_col: str, connection: sa.Connection, data: Dict[int, str]
):
    """Write migrated values back into *column_name*, one UPDATE per entry in *data*.

    ``exec_driver_sql`` hands the SQL string straight to the DBAPI cursor, so
    the placeholder syntax must match the driver's own ``paramstyle`` (PEP 249).
    The placeholders are therefore derived from ``connection.dialect.paramstyle``
    rather than guessed from the dialect name. The previous name-based guess
    broke e.g. MySQL ("format") and asyncpg ("numeric_dollar") drivers.

    :param table_name: Name of the table (without prefix)
    :param column_name: Name of the column to update
    :param id_col: Name of the ID column
    :param connection: Database connection
    :param data: Migrated data used to update the table, as {row id -> new value}
    :raises ValueError: if the driver's paramstyle is not one handled here.
    """
    qualified_table = get_qualified_table_name(table_name)
    if not data:
        # Nothing to write. This also avoids passing an empty parameter list to
        # exec_driver_sql, which would execute the statement with unbound placeholders.
        return

    # (value placeholder, id placeholder) for each supported DBAPI paramstyle.
    placeholders = {
        "qmark": ("?", "?"),
        "format": ("%s", "%s"),
        "pyformat": ("%s", "%s"),
        "numeric": (":1", ":2"),
        "numeric_dollar": ("$1", "$2"),
        "named": (":value", ":row_id"),
    }
    paramstyle = connection.dialect.paramstyle
    try:
        value_ph, id_ph = placeholders[paramstyle]
    except KeyError:
        raise ValueError(f"Unsupported DBAPI paramstyle {paramstyle!r}") from None

    sql = f"UPDATE {qualified_table} SET {column_name} = {value_ph} WHERE {id_col} = {id_ph}"
    if paramstyle == "named":
        params = [{"value": value, "row_id": row_id} for row_id, value in data.items()]
    else:
        params = [(value, row_id) for row_id, value in data.items()]
    # A list of parameter sets makes exec_driver_sql use cursor.executemany(),
    # replacing one statement round-trip per row.
    connection.exec_driver_sql(sql, params)


def get_table_columns(table_name: str, inspector: sa.Inspector) -> list[str]:
    """Get list of column names for a table.

    The lookup is made in ``DB_SCHEMA`` (when set), the same schema that
    ``table_exists`` checks. Previously the inspector was queried without a
    schema and could therefore inspect a different table, or none.

    :param table_name: Name of the table (without prefix)
    :param inspector: SQLAlchemy inspector instance
    :return: List of column names; empty if the table does not exist
    """
    if not table_exists(table_name):
        return []
    pfx = get_prefixed_table_name(table_name)
    return [col["name"] for col in inspector.get_columns(pfx, schema=get_db_schema())]


def get_table_foreign_keys(table_name: str, inspector: sa.Inspector) -> list[str]:
    """Get list of foreign key names for a table.

    The lookup is made in ``DB_SCHEMA`` (when set), the same schema that
    ``table_exists`` checks. Previously the inspector was queried without a
    schema and could therefore inspect a different table, or none.
    NOTE(review): some backends may report unnamed constraints with a ``None``
    name — confirm whether callers need to filter those out.

    :param table_name: Name of the table (without prefix)
    :param inspector: SQLAlchemy inspector instance
    :return: List of foreign key constraint names; empty if the table does not exist
    """
    if not table_exists(table_name):
        return []
    pfx = get_prefixed_table_name(table_name)
    return [fk["name"] for fk in inspector.get_foreign_keys(pfx, schema=get_db_schema())]


def get_table_indexes(table_name: str, inspector: sa.Inspector) -> list[str]:
    """Get list of index names for a table.

    The lookup is made in ``DB_SCHEMA`` (when set), the same schema that
    ``table_exists`` checks. Previously the inspector was queried without a
    schema and could therefore inspect a different table, or none.

    :param table_name: Name of the table (without prefix)
    :param inspector: SQLAlchemy inspector instance
    :return: List of index names; empty if the table does not exist
    """
    if not table_exists(table_name):
        return []
    pfx = get_prefixed_table_name(table_name)
    return [ix["name"] for ix in inspector.get_indexes(pfx, schema=get_db_schema())]
