"""Step 3: Schema alterations and data migration

=== STATE TRANSITION ===
INITIAL:  Tables exist with:
          - TEXT columns (id, owner, status, etc.)
          - TEXT timestamps with 'Z' suffix format
          - TEXT trace data (uncompressed JSON)
          - Missing foreign keys
          - 'metadata' column in derived_documents
RESULT:   Tables with:
          - String(255) columns for proper indexing
          - DateTime columns with normalized ISO format
          - LargeBinary trace (compressed with zlib)
          - Proper foreign keys (agent_shares→agents, messages→conversations, etc.)
          - 'document_metadata' column (renamed)

DOWNGRADE: Drops foreign keys, reverts the 'document_metadata' rename,
           re-adds the legacy 'last_updated' column on conversations
           Does NOT revert column types (backward compatible, safer)
========================

Revision ID: 35647ca7796c
Revises: 34647ca7796c
Create Date: 2025-11-25 16:30:00.000000
"""

import os
import sys
import zlib
from datetime import datetime
from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# Add alembic folder to path for utils import
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import (
    batch_alter_table,
    get_data_to_migrate,
    get_db_schema,
    get_prefixed_table_name,
    get_qualified_table_name,
    get_table_columns,
    get_table_foreign_keys,
    timed_step,
    update_table_with_migrated_data,
)
from utils import (
    migration_logger as logger,
)
from utils import (
    table_exists as _table_exists,
)

# revision identifiers, used by Alembic.
revision: str = "35647ca7796c"
down_revision: Union[str, Sequence[str], None] = "34647ca7796c"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

# Length used for every TEXT -> sa.String(...) column conversion in this migration.
STR_LEN = 255


def get_datetime_type():
    """Return the DateTime column type to use for the active database dialect.

    MySQL gets ``DATETIME(fsp=6)``; every other dialect gets a plain
    ``sa.DateTime()``.
    """
    dialect_name = op.get_bind().dialect.name
    if dialect_name != "mysql":
        return sa.DateTime()

    # Imported inside the MySQL branch so the dialect module is only loaded when used.
    from sqlalchemy.dialects.mysql import DATETIME

    return DATETIME(fsp=6)


def get_current_timestamp():
    """Return the server-default "current timestamp" expression for the active dialect.

    MySQL gets the literal ``CURRENT_TIMESTAMP(6)`` (paired with the
    ``DATETIME(fsp=6)`` type returned by ``get_datetime_type``); other dialects
    get SQLAlchemy's generic ``current_timestamp()`` function.
    """
    running_on_mysql = op.get_bind().dialect.name == "mysql"
    return sa.text("CURRENT_TIMESTAMP(6)") if running_on_mysql else sa.func.current_timestamp()


# strptime patterns tried, in order, when datetime.fromisoformat() rejects a string.
_DATETIME_FALLBACK_FORMATS = (
    "%Y-%m-%dT%H:%M:%S.%f%z",
    "%Y-%m-%dT%H:%M:%S%z",
    "%Y-%m-%dT%H:%M:%S.%f",
    "%Y-%m-%dT%H:%M:%S",
    "%Y-%m-%d %H:%M:%S.%f",
    "%Y-%m-%d %H:%M:%S",
)


def _format_as_naive_utc(dt):
    """Format *dt* as 'YYYY-MM-DD HH:MM:SS.ffffff', shifting aware values to UTC first."""
    offset = dt.utcoffset()
    if offset is not None:
        # The output format carries no offset, so an aware value must be moved to
        # UTC before tzinfo is dropped; otherwise e.g. "12:00+02:00" would be
        # stored as 12:00 instead of 10:00.
        dt = (dt - offset).replace(tzinfo=None)
    return dt.strftime("%Y-%m-%d %H:%M:%S.%f")


def datetime_transformation(value):
    """Normalize a datetime value to the string 'YYYY-MM-DD HH:MM:SS.ffffff'.

    Args:
        value: ``None``, a ``datetime`` instance, or a string in ISO-8601-like
            form (a trailing ``'Z'`` is accepted and treated as ``+00:00``).

    Returns:
        The normalized string, or ``None`` when *value* is ``None`` or cannot be
        parsed (a warning is logged in the latter case). Timezone-aware inputs
        are converted to UTC before formatting; naive inputs are formatted as-is.
    """
    if value is None:
        return None
    if isinstance(value, datetime):
        return _format_as_naive_utc(value)
    if isinstance(value, str):
        # Rewrite a 'Z' suffix as an explicit UTC offset so fromisoformat() and
        # the %z fallback patterns can parse it.
        candidate = value[:-1] + "+00:00" if value.endswith("Z") else value
        try:
            return _format_as_naive_utc(datetime.fromisoformat(candidate))
        except ValueError:
            pass
        # Fallback: try explicit strptime patterns for edge cases.
        for fmt in _DATETIME_FALLBACK_FORMATS:
            try:
                return _format_as_naive_utc(datetime.strptime(candidate, fmt))
            except ValueError:
                continue
    logger.warning("Could not parse datetime value: %s", value)
    return None


def compress_trace_transformation(value):
    """Compress trace data with zlib.

    Args:
        value: ``None``, a ``str`` (encoded as UTF-8 before compression), or a
            bytes-like object (``bytes``, ``bytearray``, ``memoryview``).

    Returns:
        ``bytes`` holding the zlib-compressed payload, or ``None`` when *value*
        is ``None`` or of an unsupported type (a warning is logged in the
        latter case so the dropped value does not go unnoticed). Bytes-like
        input that already starts with the zlib header ``78 9c`` is returned
        unchanged, which keeps the transformation safe to re-run.
    """
    if value is None:
        return None
    if isinstance(value, str):
        return zlib.compress(value.encode("utf-8"))
    if isinstance(value, (bytes, bytearray, memoryview)):
        raw = bytes(value)
        # 0x78 0x9c is the header zlib.compress() emits at its default level;
        # treat such data as already compressed instead of compressing twice.
        if raw[:2] == b"\x78\x9c":
            return raw
        return zlib.compress(raw)
    logger.warning("Could not compress trace value of type %s", type(value).__name__)
    return None


def _table_has_rows(bind, table_name):
    """Return True when ``SELECT COUNT(*)`` on *table_name* reports at least one row."""
    count_row = bind.execute(
        sa.text(f"SELECT COUNT(*) FROM {get_qualified_table_name(table_name)}")
    ).fetchone()
    return bool(count_row and count_row[0] > 0)


def _sampled_value_is_text(bind, table_name, column_name, skip_nulls=False):
    """Return True when one sampled value of *column_name* comes back as a non-empty ``str``.

    Only a single row is inspected (``LIMIT 1``). With ``skip_nulls=True`` the
    sample is restricted to rows where the column is not NULL. The table and
    column names are interpolated into the SQL text; every caller in this file
    passes hard-coded identifiers.
    """
    qualified_name = get_qualified_table_name(table_name)
    null_filter = f" WHERE {column_name} IS NOT NULL" if skip_nulls else ""
    row = bind.execute(sa.text(f"SELECT {column_name} FROM {qualified_name}{null_filter} LIMIT 1")).fetchone()
    return bool(row and row[0] and isinstance(row[0], str))


def _needs_datetime_migration(bind, table_name):
    """Return True when *table_name* has rows and its sampled ``created_at`` value is a string.

    Any error raised by the probe queries is logged at debug level and treated
    as "no migration needed".
    """
    try:
        return _table_has_rows(bind, table_name) and _sampled_value_is_text(bind, table_name, "created_at")
    except Exception as e:
        logger.debug(f"    Skipping migration check: {e}")
        return False


def upgrade() -> None:
    """Perform schema alterations and data migrations.

    Tables are processed in a fixed order, and each one is skipped when it does
    not exist. For agents, conversations, draft_conversations, messages and
    preferences the column rewrite only runs when the table has rows and a
    sampled ``created_at`` value is still a string (messages additionally runs
    it when a sampled ``trace`` value is a string). The agent_shares and
    message_agents type changes run whenever the table exists; their foreign
    keys, the messages foreign key and the derived_documents rename are only
    applied when missing.

    Datetime and trace columns are migrated by reading the transformed values
    first, dropping and re-adding the column with its new type, and then writing
    the transformed values back.
    """
    with timed_step("Step 3/4: Schema alterations"):
        bind = op.get_bind()
        inspector = sa.inspect(bind)

        CREATED_AT = "created_at"
        LAST_MODIFIED = "last_modified"
        # NOTE(review): these two Column objects are passed to add_column() for
        # several tables; this presumably relies on Alembic copying the column
        # per operation - confirm.
        CREATED_AT_SCHEMA = sa.Column(
            CREATED_AT, get_datetime_type(), nullable=False, server_default=get_current_timestamp()
        )
        LAST_MODIFIED_SCHEMA = sa.Column(
            LAST_MODIFIED, get_datetime_type(), nullable=False, server_default=get_current_timestamp()
        )

        # ========== AGENTS TABLE ==========
        if _table_exists("agents"):
            logger.info("  Processing agents table...")

            # Datetime migration is needed when created_at is still a string - only if table has data
            if _needs_datetime_migration(bind, "agents"):
                logger.info("    Migrating datetime columns...")
                # Read the transformed values before the columns are dropped below.
                created_at_data = get_data_to_migrate("agents", CREATED_AT, "id", bind, datetime_transformation)
                last_modified_data = get_data_to_migrate("agents", LAST_MODIFIED, "id", bind, datetime_transformation)
                published_at_data = get_data_to_migrate("agents", "published_at", "id", bind, datetime_transformation)

                with batch_alter_table("agents") as batch_op:
                    batch_op.alter_column("id", existing_type=sa.TEXT(), type_=sa.String(STR_LEN), nullable=False)
                    batch_op.alter_column(
                        "owner", existing_type=sa.TEXT(), type_=sa.String(STR_LEN), existing_nullable=False
                    )
                    batch_op.alter_column(
                        "publishing_status",
                        existing_type=sa.TEXT(),
                        type_=sa.Enum(
                            "idle", "publishing", "published", "failed", name="publishingstatusenum", native_enum=False
                        ),
                        existing_nullable=True,
                    )
                    batch_op.alter_column(
                        "publishing_job_id", existing_type=sa.TEXT(), type_=sa.String(STR_LEN), existing_nullable=True
                    )
                    batch_op.drop_column(CREATED_AT)
                    batch_op.add_column(CREATED_AT_SCHEMA)
                    batch_op.drop_column(LAST_MODIFIED)
                    batch_op.add_column(LAST_MODIFIED_SCHEMA)
                    batch_op.drop_column("published_at")
                    batch_op.add_column(sa.Column("published_at", get_datetime_type(), nullable=True))

                # Write the transformed values back into the re-created columns.
                update_table_with_migrated_data("agents", CREATED_AT, "id", bind, created_at_data)
                update_table_with_migrated_data("agents", LAST_MODIFIED, "id", bind, last_modified_data)
                update_table_with_migrated_data("agents", "published_at", "id", bind, published_at_data)

        # ========== AGENT_SHARES TABLE ==========
        if _table_exists("agent_shares"):
            logger.info("  Processing agent_shares table...")
            fks = get_table_foreign_keys("agent_shares", inspector)

            # The type changes run unconditionally; only the FK is guarded.
            with batch_alter_table("agent_shares") as batch_op:
                batch_op.alter_column(
                    "agent_id", existing_type=sa.TEXT(), type_=sa.String(STR_LEN), existing_nullable=False
                )
                batch_op.alter_column(
                    "principal", existing_type=sa.TEXT(), type_=sa.String(STR_LEN), existing_nullable=False
                )
                batch_op.alter_column(
                    "principal_type",
                    existing_type=sa.TEXT(),
                    type_=sa.Enum("user", "group", name="principaltypeenum", native_enum=False),
                    existing_nullable=False,
                )
                if "fk_agent_shares_agent_id_agents" not in fks:
                    batch_op.create_foreign_key(
                        "fk_agent_shares_agent_id_agents",
                        get_prefixed_table_name("agents"),
                        ["agent_id"],
                        ["id"],
                        referent_schema=get_db_schema(),
                    )

        # ========== CONVERSATIONS TABLE ==========
        if _table_exists("conversations"):
            logger.info("  Processing conversations table...")
            columns = get_table_columns("conversations", inspector)

            # Datetime migration is needed when created_at is still a string (only if table has data)
            if _needs_datetime_migration(bind, "conversations"):
                logger.info("    Migrating datetime columns...")
                created_at_data = get_data_to_migrate(
                    "conversations", CREATED_AT, "conversation_id", bind, datetime_transformation
                )
                last_modified_data = get_data_to_migrate(
                    "conversations", LAST_MODIFIED, "conversation_id", bind, datetime_transformation
                )

                with batch_alter_table("conversations") as batch_op:
                    # Legacy column; downgrade() re-adds it.
                    if "last_updated" in columns:
                        batch_op.drop_column("last_updated")
                    batch_op.alter_column(
                        "conversation_id", existing_type=sa.TEXT(), type_=sa.String(STR_LEN), nullable=False
                    )
                    batch_op.alter_column(
                        "user_id", existing_type=sa.TEXT(), type_=sa.String(STR_LEN), existing_nullable=False
                    )
                    batch_op.alter_column(
                        "status",
                        existing_type=sa.TEXT(),
                        type_=sa.Enum("active", "deleted", name="statusenum", native_enum=False),
                        existing_nullable=True,
                        existing_server_default=sa.text("'active'"),
                    )
                    batch_op.drop_column(CREATED_AT)
                    batch_op.add_column(CREATED_AT_SCHEMA)
                    batch_op.drop_column(LAST_MODIFIED)
                    batch_op.add_column(LAST_MODIFIED_SCHEMA)

                update_table_with_migrated_data("conversations", CREATED_AT, "conversation_id", bind, created_at_data)
                update_table_with_migrated_data(
                    "conversations", LAST_MODIFIED, "conversation_id", bind, last_modified_data
                )

        # ========== DRAFT_CONVERSATIONS TABLE ==========
        if _table_exists("draft_conversations"):
            logger.info("  Processing draft_conversations table...")

            if _needs_datetime_migration(bind, "draft_conversations"):
                logger.info("    Migrating datetime columns...")
                created_at_data = get_data_to_migrate(
                    "draft_conversations", CREATED_AT, "agent_id", bind, datetime_transformation
                )
                last_modified_data = get_data_to_migrate(
                    "draft_conversations", LAST_MODIFIED, "agent_id", bind, datetime_transformation
                )

                with batch_alter_table("draft_conversations") as batch_op:
                    batch_op.alter_column(
                        "agent_id", existing_type=sa.TEXT(), type_=sa.String(STR_LEN), existing_nullable=False
                    )
                    batch_op.alter_column(
                        "user_id", existing_type=sa.TEXT(), type_=sa.String(STR_LEN), existing_nullable=False
                    )
                    batch_op.drop_column(CREATED_AT)
                    batch_op.add_column(CREATED_AT_SCHEMA)
                    batch_op.drop_column(LAST_MODIFIED)
                    batch_op.add_column(LAST_MODIFIED_SCHEMA)

                update_table_with_migrated_data("draft_conversations", CREATED_AT, "agent_id", bind, created_at_data)
                update_table_with_migrated_data(
                    "draft_conversations", LAST_MODIFIED, "agent_id", bind, last_modified_data
                )

        # ========== MESSAGES TABLE ==========
        if _table_exists("messages"):
            logger.info("  Processing messages table...")
            columns = get_table_columns("messages", inspector)
            fks = get_table_foreign_keys("messages", inspector)

            needs_datetime_migration = False
            needs_trace_compression = False

            # Both probes share one try block: if the trace probe fails, the
            # datetime flag computed before it is kept.
            try:
                if _table_has_rows(bind, "messages"):
                    # Check datetime migration
                    needs_datetime_migration = _sampled_value_is_text(bind, "messages", CREATED_AT)

                    # Check if trace needs compression
                    if "trace" in columns:
                        needs_trace_compression = _sampled_value_is_text(
                            bind, "messages", "trace", skip_nulls=True
                        )
            except Exception as e:
                logger.debug(f"    Skipping migration check: {e}")

            if needs_datetime_migration or needs_trace_compression:
                logger.info("    Migrating datetime and trace columns...")

                created_at_data = {}
                feedback_updated_at_data = {}
                trace_data = {}

                if needs_datetime_migration:
                    created_at_data = get_data_to_migrate("messages", CREATED_AT, "id", bind, datetime_transformation)
                    feedback_updated_at_data = get_data_to_migrate(
                        "messages", "feedback_updated_at", "id", bind, datetime_transformation
                    )

                if needs_trace_compression and "trace" in columns:
                    trace_data = get_data_to_migrate("messages", "trace", "id", bind, compress_trace_transformation)

                with batch_alter_table("messages") as batch_op:
                    batch_op.alter_column("id", existing_type=sa.TEXT(), type_=sa.String(STR_LEN), nullable=False)
                    batch_op.alter_column(
                        "conversation_id", existing_type=sa.TEXT(), type_=sa.String(STR_LEN), existing_nullable=False
                    )
                    batch_op.alter_column(
                        "role", existing_type=sa.TEXT(), type_=sa.String(STR_LEN), existing_nullable=False
                    )
                    batch_op.alter_column(
                        "feedback_by", existing_type=sa.TEXT(), type_=sa.String(STR_LEN), existing_nullable=True
                    )
                    batch_op.alter_column(
                        "status",
                        existing_type=sa.TEXT(),
                        type_=sa.Enum("active", "deleted", name="statusenum", native_enum=False),
                        existing_nullable=True,
                        existing_server_default=sa.text("'active'"),
                    )
                    if needs_datetime_migration:
                        batch_op.drop_column(CREATED_AT)
                        batch_op.add_column(CREATED_AT_SCHEMA)
                        batch_op.drop_column("feedback_updated_at")
                        # NOTE(review): the re-added column is nullable but has a
                        # current-timestamp default; rows whose old value was NULL
                        # may end up with the default unless the write-back below
                        # restores NULL - confirm against update_table_with_migrated_data.
                        batch_op.add_column(
                            sa.Column(
                                "feedback_updated_at",
                                get_datetime_type(),
                                nullable=True,
                                server_default=get_current_timestamp(),
                            )
                        )

                    if needs_trace_compression and "trace" in columns:
                        batch_op.drop_column("trace")
                        batch_op.add_column(sa.Column("trace", sa.LargeBinary(), nullable=True))
                    elif "trace" not in columns:
                        batch_op.add_column(sa.Column("trace", sa.LargeBinary(), nullable=True))

                if created_at_data:
                    update_table_with_migrated_data("messages", CREATED_AT, "id", bind, created_at_data)
                if feedback_updated_at_data:
                    update_table_with_migrated_data(
                        "messages", "feedback_updated_at", "id", bind, feedback_updated_at_data
                    )
                if trace_data:
                    update_table_with_migrated_data("messages", "trace", "id", bind, trace_data)

            # Add FK if missing
            if "fk_messages_conversation_id_conversations" not in fks:
                logger.info("    Adding foreign key...")
                with batch_alter_table("messages") as batch_op:
                    batch_op.create_foreign_key(
                        "fk_messages_conversation_id_conversations",
                        get_prefixed_table_name("conversations"),
                        ["conversation_id"],
                        ["conversation_id"],
                        referent_schema=get_db_schema(),
                    )

        # ========== MESSAGE_AGENTS TABLE ==========
        if _table_exists("message_agents"):
            logger.info("  Processing message_agents table...")
            fks = get_table_foreign_keys("message_agents", inspector)

            # The type changes run unconditionally; only the FK is guarded.
            with batch_alter_table("message_agents") as batch_op:
                batch_op.alter_column(
                    "message_id", existing_type=sa.TEXT(), type_=sa.String(STR_LEN), existing_nullable=False
                )
                batch_op.alter_column(
                    "agent_id", existing_type=sa.TEXT(), type_=sa.String(STR_LEN), existing_nullable=False
                )

            # Add FK if missing
            if "fk_message_agents_message_id_messages" not in fks:
                logger.info("    Adding foreign key...")
                with batch_alter_table("message_agents") as batch_op:
                    batch_op.create_foreign_key(
                        "fk_message_agents_message_id_messages",
                        get_prefixed_table_name("messages"),
                        ["message_id"],
                        ["id"],
                        referent_schema=get_db_schema(),
                    )

        # ========== DERIVED_DOCUMENTS TABLE ==========
        if _table_exists("derived_documents"):
            columns = get_table_columns("derived_documents", inspector)
            if "metadata" in columns and "document_metadata" not in columns:
                logger.info("  Renaming metadata column in derived_documents...")
                with batch_alter_table("derived_documents") as batch_op:
                    batch_op.alter_column("metadata", new_column_name="document_metadata")

        # ========== MESSAGE_ATTACHMENTS TABLE ==========
        if _table_exists("message_attachments"):
            logger.info("  Processing message_attachments table...")
            columns = get_table_columns("message_attachments", inspector)

            # The presence of the legacy 'updated_at' column is the trigger here
            # (no row-count or value probe): it is replaced by 'last_modified'.
            if "updated_at" in columns:
                logger.info("    Migrating datetime columns...")
                updated_at_data = get_data_to_migrate(
                    "message_attachments", "updated_at", "message_id", bind, datetime_transformation
                )
                created_at_data = get_data_to_migrate(
                    "message_attachments", CREATED_AT, "message_id", bind, datetime_transformation
                )

                with batch_alter_table("message_attachments") as batch_op:
                    batch_op.drop_column(CREATED_AT)
                    batch_op.add_column(CREATED_AT_SCHEMA)
                    batch_op.drop_column("updated_at")
                    batch_op.add_column(LAST_MODIFIED_SCHEMA)

                # Old 'updated_at' values are written into the new 'last_modified' column.
                update_table_with_migrated_data(
                    "message_attachments", LAST_MODIFIED, "message_id", bind, updated_at_data
                )
                update_table_with_migrated_data("message_attachments", CREATED_AT, "message_id", bind, created_at_data)

        # ========== PREFERENCES TABLE ==========
        if _table_exists("preferences"):
            logger.info("  Processing preferences table...")

            if _needs_datetime_migration(bind, "preferences"):
                logger.info("    Migrating datetime columns...")
                created_at_data = get_data_to_migrate(
                    "preferences", CREATED_AT, "user_id", bind, datetime_transformation
                )
                last_modified_data = get_data_to_migrate(
                    "preferences", LAST_MODIFIED, "user_id", bind, datetime_transformation
                )

                with batch_alter_table("preferences") as batch_op:
                    batch_op.alter_column("user_id", existing_type=sa.TEXT(), type_=sa.String(STR_LEN), nullable=False)
                    batch_op.drop_column(CREATED_AT)
                    batch_op.add_column(CREATED_AT_SCHEMA)
                    batch_op.drop_column(LAST_MODIFIED)
                    batch_op.add_column(LAST_MODIFIED_SCHEMA)

                update_table_with_migrated_data("preferences", CREATED_AT, "user_id", bind, created_at_data)
                update_table_with_migrated_data("preferences", LAST_MODIFIED, "user_id", bind, last_modified_data)


def downgrade() -> None:
    """Revert schema alterations.

    This reverses:
    - Foreign key additions
    - The derived_documents 'document_metadata' column rename
    - Removal of the legacy conversations 'last_updated' column (re-added as
      nullable Text and backfilled from 'last_modified' on a best-effort basis)

    NOTE: Column type changes (String, Enum, DateTime, LargeBinary) are NOT reverted
    because they are backward compatible and reverting them risks data loss.
    """
    logger.info("[Downgrade Step 3] Reverting schema alterations...")

    bind = op.get_bind()
    inspector = sa.inspect(bind)
    # Reflect in the same schema the rest of this migration targets; when
    # get_db_schema() returns None this is the inspector's default behavior.
    schema = get_db_schema()

    def get_columns(table_name: str) -> list:
        """Return the column names of *table_name*, or [] when the table is absent."""
        if not _table_exists(table_name):
            return []
        pfx = get_prefixed_table_name(table_name)
        return [col["name"] for col in inspector.get_columns(pfx, schema=schema)]

    def get_fks(table_name: str) -> list:
        """Return the foreign-key names of *table_name*, or [] when the table is absent."""
        if not _table_exists(table_name):
            return []
        pfx = get_prefixed_table_name(table_name)
        return [fk["name"] for fk in inspector.get_foreign_keys(pfx, schema=schema)]

    # Remove foreign keys added in Step 3
    logger.info("  Removing foreign keys added in Step 3...")

    added_foreign_keys = (
        ("agent_shares", "fk_agent_shares_agent_id_agents"),
        ("messages", "fk_messages_conversation_id_conversations"),
        ("message_agents", "fk_message_agents_message_id_messages"),
    )
    for table_name, fk_name in added_foreign_keys:
        # get_fks() returns [] for a missing table, so the drop is skipped then.
        if fk_name in get_fks(table_name):
            logger.info(f"    Dropping {fk_name}")
            with batch_alter_table(table_name) as batch_op:
                batch_op.drop_constraint(fk_name, type_="foreignkey")

    # Revert column renames
    if _table_exists("derived_documents"):
        columns = get_columns("derived_documents")
        if "document_metadata" in columns and "metadata" not in columns:
            logger.info("  Reverting document_metadata → metadata in derived_documents")
            with batch_alter_table("derived_documents") as batch_op:
                batch_op.alter_column("document_metadata", new_column_name="metadata")

    # Restore legacy last_updated column for old plugin compatibility
    if _table_exists("conversations"):
        columns = get_columns("conversations")
        if "last_updated" not in columns:
            logger.info("  Re-adding last_updated column in conversations")
            with batch_alter_table("conversations") as batch_op:
                batch_op.add_column(sa.Column("last_updated", sa.Text(), nullable=True))
            if "last_modified" in columns:
                # Best-effort backfill: a failure is logged and does not abort the downgrade.
                try:
                    bind.execute(
                        sa.text(
                            f"UPDATE {get_qualified_table_name('conversations')} "
                            "SET last_updated = last_modified "
                            "WHERE last_updated IS NULL"
                        )
                    )
                except Exception as e:
                    logger.warning("    Could not backfill last_updated from last_modified: %s", e)

    logger.info("[Downgrade Step 3] Complete")
    logger.info("[Note] Column type changes were NOT reverted for data safety")
