import logging
import random
from dataclasses import dataclass
from typing import List, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from pandas.api.types import is_bool_dtype, is_datetime64_any_dtype, is_numeric_dtype
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from torch.utils.data import DataLoader, TensorDataset

import dataiku


# Module-level logger; the recipe helpers below report progress and data-cleaning steps through it.
logger = logging.getLogger(__name__)


@dataclass
class CVAEOversampleConfig:
    """Validated settings for the CVAE oversampling recipe.

    Build instances with :meth:`from_recipe_config` from the raw recipe parameter dict.
    """

    target_column: str  # binary target column to rebalance
    random_seed: Optional[int]  # None disables explicit seeding
    latent_dim: int
    hidden_dims: List[int]  # encoder layer widths (the decoder uses them in reverse)
    dropout: float
    max_categories: int  # <= 0 means no category limiting
    max_rows: int  # 0 means train on every row
    epochs: int
    batch_size: int
    learning_rate: float
    weight_decay: float
    kl_beta: float  # final weight of the KL term
    kl_warmup_epochs: int  # epochs over which the KL weight ramps up to kl_beta
    max_synthetic_multiplier: float  # cap on synthetic rows, as a multiple of the input row count

    @staticmethod
    def _read(config, key, default, cast):
        """Return ``cast(config[key])``, or *default* when the key is absent, None or a blank string."""
        value = config.get(key)
        if value is None or (isinstance(value, str) and not value.strip()):
            return default
        return cast(value)

    def _validate(self):
        """Raise ValueError for settings the model or the DataLoader cannot work with."""
        if self.batch_size < 1:
            raise ValueError("batch_size must be >= 1; got %d" % self.batch_size)
        if self.latent_dim < 1 or min(self.hidden_dims) < 1:
            raise ValueError(
                "latent_dim and hidden dimensions must be >= 1; got latent_dim=%d, hidden_dims=%r"
                % (self.latent_dim, self.hidden_dims)
            )
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError("dropout must be between 0 and 1; got %r" % self.dropout)

    @classmethod
    def from_recipe_config(cls, config):
        """Build a config from a recipe parameter dict.

        Parameters that are missing, ``None`` or blank fall back to their defaults,
        so UI fields left empty no longer crash the ``int()``/``float()`` conversions.
        ``random_seed`` defaults to 42 when absent, is disabled (None) when explicitly
        None/blank, and is otherwise coerced to ``int`` so NumPy and pandas accept it.

        Raises:
            ValueError: if ``target_column`` is missing or a numeric setting is out of range.
        """
        config = config or {}
        target_column = config.get("target_column")
        if not target_column:
            raise ValueError("target_column is required")

        read = cls._read
        random_seed = read(config, "random_seed", None, int) if "random_seed" in config else 42

        cfg = cls(
            target_column=target_column,
            random_seed=random_seed,
            latent_dim=read(config, "latent_dim", 16, int),
            hidden_dims=[read(config, "hidden_dim_1", 256, int), read(config, "hidden_dim_2", 128, int)],
            dropout=read(config, "dropout", 0.1, float),
            max_categories=read(config, "max_categories", 0, int),
            max_rows=read(config, "max_rows", 0, int),
            epochs=read(config, "epochs", 50, int),
            batch_size=read(config, "batch_size", 512, int),
            learning_rate=read(config, "learning_rate", 0.001, float),
            weight_decay=read(config, "weight_decay", 0.00001, float),
            kl_beta=read(config, "kl_beta", 0.5, float),
            kl_warmup_epochs=read(config, "kl_warmup_epochs", 10, int),
            max_synthetic_multiplier=read(config, "max_synthetic_multiplier", 3.0, float),
        )
        cfg._validate()
        return cfg


class CVAE(nn.Module):
    """Conditional VAE over a preprocessed feature vector, conditioned on a one-hot label.

    The decoder has two output heads: ``num_out`` regresses the numeric block and
    ``cat_out`` emits raw logits for the one-hot categorical block. A head is
    ``None`` when its block has zero width.
    """

    def __init__(self, x_dim, y_dim, num_dim, cat_dim, latent_dim, hidden_dims, dropout):
        super().__init__()
        self.num_dim = num_dim
        self.cat_dim = cat_dim

        # Encoder: [x, y] -> hidden stack -> (mu, logvar).
        self.encoder, enc_width = self._mlp(x_dim + y_dim, hidden_dims, dropout)
        self.mu = nn.Linear(enc_width, latent_dim)
        self.logvar = nn.Linear(enc_width, latent_dim)

        # Decoder mirrors the encoder widths: [z, y] -> reversed hidden stack -> heads.
        self.decoder, dec_width = self._mlp(latent_dim + y_dim, list(reversed(hidden_dims)), dropout)
        self.num_out = nn.Linear(dec_width, num_dim) if num_dim > 0 else None
        self.cat_out = nn.Linear(dec_width, cat_dim) if cat_dim > 0 else None

    @staticmethod
    def _mlp(width, layer_widths, dropout):
        """Return ``(Sequential of Linear/ReLU/Dropout triples, final width)``."""
        layers = []
        for out_width in layer_widths:
            layers += [nn.Linear(width, out_width), nn.ReLU(), nn.Dropout(dropout)]
            width = out_width
        return nn.Sequential(*layers), width

    def encode(self, x, y_onehot):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x`` given ``y_onehot``."""
        hidden = self.encoder(torch.cat((x, y_onehot), dim=1))
        mu = self.mu(hidden)
        logvar = self.logvar(hidden)
        return mu, logvar

    @staticmethod
    def reparameterize(mu, logvar):
        """Sample ``z = mu + eps * sigma`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``."""
        sigma = (0.5 * logvar).exp()
        return mu + torch.randn_like(sigma) * sigma

    def decode(self, z, y_onehot):
        """Return ``(num_pred, cat_logits)``; either is None when its head does not exist."""
        hidden = self.decoder(torch.cat((z, y_onehot), dim=1))
        num_pred = None if self.num_out is None else self.num_out(hidden)
        cat_logits = None if self.cat_out is None else self.cat_out(hidden)
        return num_pred, cat_logits

    def forward(self, x, y_onehot):
        """Encode, sample a latent code, decode; return ``(num_pred, cat_logits, mu, logvar)``."""
        mu, logvar = self.encode(x, y_onehot)
        num_pred, cat_logits = self.decode(self.reparameterize(mu, logvar), y_onehot)
        return num_pred, cat_logits, mu, logvar


def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs.

    Does nothing when *seed* is None.
    """
    if seed is not None:
        for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
            seeder(seed)


def make_one_hot_encoder():
    """Return a dense ``OneHotEncoder`` that ignores categories unseen at fit time.

    The ``sparse_output`` keyword is tried first; if the installed scikit-learn
    rejects it with ``TypeError``, the older ``sparse`` keyword is used instead.
    """
    options = {"handle_unknown": "ignore"}
    try:
        return OneHotEncoder(sparse_output=False, **options)
    except TypeError:  # older scikit-learn without ``sparse_output``
        return OneHotEncoder(sparse=False, **options)


def split_columns(df, target_col):
    """Partition the non-target columns of *df* into numeric and categorical groups.

    Side effect: datetime columns are converted IN PLACE on *df* to strings, so they
    can be one-hot encoded as categories. Missing timestamps stay missing (NaN)
    rather than becoming the literal string ``"NaT"``. This lets all-null datetime
    columns still be detected as empty, and lets partially-null ones be imputed
    like any other categorical column. Boolean columns are treated as categorical.

    Returns:
        ``(feature_cols, numeric_cols, categorical_cols)``, each in *df* column order.

    Raises:
        ValueError: if *df* has no column other than *target_col*.
    """
    feature_cols = [c for c in df.columns if c != target_col]
    if not feature_cols:
        raise ValueError("No feature columns found after removing target column.")

    numeric_cols = []
    categorical_cols = []
    for col in feature_cols:
        series = df[col]
        if is_datetime64_any_dtype(series):
            # astype(str) renders NaT as "NaT"; mask it back to NaN to keep missingness.
            df[col] = series.astype(str).where(series.notna())
            categorical_cols.append(col)
        elif is_bool_dtype(series):
            categorical_cols.append(col)
        elif is_numeric_dtype(series):
            numeric_cols.append(col)
        else:
            categorical_cols.append(col)

    return feature_cols, numeric_cols, categorical_cols


def build_preprocessor(numeric_cols, categorical_cols):
    """Return an unfitted ``ColumnTransformer`` for the feature columns.

    * numeric columns: median imputation, then standardization (branch ``"num"``);
    * categorical columns: most-frequent imputation, then dense one-hot (branch ``"cat"``).

    A branch is omitted when its column list is empty; any other column is dropped.
    The numeric branch is listed first, so its outputs precede the one-hot block.
    """
    branches = []
    if numeric_cols:
        numeric_steps = [
            ("imputer", SimpleImputer(strategy="median")),
            ("scaler", StandardScaler()),
        ]
        branches.append(("num", Pipeline(steps=numeric_steps), numeric_cols))
    if categorical_cols:
        categorical_steps = [
            ("imputer", SimpleImputer(strategy="most_frequent")),
            ("onehot", make_one_hot_encoder()),
        ]
        branches.append(("cat", Pipeline(steps=categorical_steps), categorical_cols))
    return ColumnTransformer(transformers=branches, remainder="drop")


def limit_categories(df, categorical_cols, max_categories):
    """Collapse rare categories into ``"__OTHER__"``.

    Each column in *categorical_cols* keeps its *max_categories* most frequent
    values (missing values are counted as a value too); every other value is
    replaced by ``"__OTHER__"``. Works on a copy; *df* itself is returned
    untouched when there are no categorical columns or *max_categories* <= 0.
    """
    if not categorical_cols or max_categories <= 0:
        return df

    limited = df.copy()
    for column in categorical_cols:
        values = limited[column]
        frequent = values.value_counts(dropna=False).index[:max_categories]
        limited[column] = values.where(values.isin(frequent), "__OTHER__")
    return limited


def drop_empty_features(df, feature_cols, numeric_cols, categorical_cols):
    """Remove feature columns whose values are all null.

    Returns:
        ``(df, feature_cols, numeric_cols, categorical_cols)`` with the all-null
        columns removed. The inputs are returned unchanged (same objects) when
        no column is empty.

    Raises:
        ValueError: if every feature column is entirely null. Nothing would be left
            to train on, and decoding would yield no rows for the synthetic data.
    """
    empty_cols = [c for c in feature_cols if df[c].isna().all()]
    if not empty_cols:
        return df, feature_cols, numeric_cols, categorical_cols
    if len(empty_cols) == len(feature_cols):
        raise ValueError(
            "All %d feature columns are entirely null; there is nothing to learn from." % len(feature_cols)
        )

    logger.info("Dropping %d all-null feature columns: %s", len(empty_cols), empty_cols)
    empty = set(empty_cols)
    df = df.drop(columns=empty_cols)
    feature_cols = [c for c in feature_cols if c not in empty]
    numeric_cols = [c for c in numeric_cols if c not in empty]
    categorical_cols = [c for c in categorical_cols if c not in empty]
    return df, feature_cols, numeric_cols, categorical_cols


def _extract_dims(preprocessor, numeric_cols, categorical_cols, x_dim):
    """Describe how the fitted preprocessor's output splits into numeric and one-hot blocks.

    Returns:
        ``(num_dim, cat_dim, cat_group_sizes)``:
        * ``num_dim``: width of the leading numeric block;
        * ``cat_dim``: width of the trailing one-hot block;
        * ``cat_group_sizes``: number of one-hot columns per categorical feature,
          in encoder order.

    If ``num_dim + cat_dim`` differs from *x_dim*, a warning is logged and
    ``num_dim`` is recomputed as ``x_dim - cat_dim`` (floored at 0). This can
    happen, for example, when the numeric imputer drops columns that were entirely
    missing at fit time — NOTE(review): confirm this is the only source of mismatch.
    """
    num_dim = 0
    if numeric_cols:
        imputer = preprocessor.named_transformers_["num"].named_steps["imputer"]
        num_dim = int(len(imputer.statistics_))

    cat_group_sizes = []
    if categorical_cols:
        encoder = preprocessor.named_transformers_["cat"].named_steps["onehot"]
        cat_group_sizes = [len(categories) for categories in encoder.categories_]
    cat_dim = int(sum(cat_group_sizes))

    if num_dim + cat_dim != x_dim:
        logger.warning(
            "Preprocessor output dim (%d) does not match num_dim (%d) + cat_dim (%d); adjusting num_dim.",
            x_dim,
            num_dim,
            cat_dim,
        )
        num_dim = max(0, x_dim - cat_dim)

    return num_dim, cat_dim, cat_group_sizes


def _prepare_target(df, target_col):
    """Drop rows with a null target and binary-encode it (minority -> 1, majority -> 0).

    Returns:
        ``(df, target_encoded, minority_value, majority_value)``. *df* is the input
        object itself when no target value is null; otherwise it is a filtered copy.
        When both classes have the same count, the first value seen is labelled
        the minority.

    Raises:
        ValueError: if the non-null target does not have exactly two distinct values.
    """
    null_mask = df[target_col].isna()
    if null_mask.any():
        n_before = len(df)
        df = df[~null_mask].copy()
        logger.info("Dropped %d rows with null target", n_before - len(df))

    distinct = pd.unique(df[target_col])
    if len(distinct) != 2:
        raise ValueError("Target column must have exactly 2 distinct values; got %d" % len(distinct))

    class_counts = df[target_col].value_counts()
    if class_counts.nunique() != 1:
        minority_value, majority_value = class_counts.idxmin(), class_counts.idxmax()
    else:
        minority_value, majority_value = distinct[0], distinct[1]
        logger.info("Target is already balanced; no synthetic rows will be generated")

    encoded = df[target_col].map({minority_value: 1, majority_value: 0})
    return df, encoded, minority_value, majority_value


def _build_numeric_stats(df, numeric_cols):
    """Summarize the observed (non-null) values of each numeric column.

    For every column with at least one non-null value, returns a dict with:
    ``min``/``max`` (floats), ``single_value`` (the constant when the column has
    exactly one distinct value, else None) and ``all_int`` (every value integral).
    Columns with no non-null values are omitted.
    """
    def summarize(values):
        constant = values.nunique(dropna=True) == 1
        return {
            "min": float(values.min()),
            "max": float(values.max()),
            "single_value": float(values.iloc[0]) if constant else None,
            "all_int": bool((values % 1 == 0).all()),
        }

    observed = {col: df[col].dropna() for col in numeric_cols}
    return {col: summarize(values) for col, values in observed.items() if not values.empty}


def _train_model(X_train, y_train, cfg, num_dim, cat_dim, latent_dim, hidden_dims, dropout):
    """Fit a CVAE on the preprocessed feature matrix and the 0/1 encoded labels.

    Per-batch loss = MSE over the first *num_dim* columns, plus BCE-with-logits over
    the remaining one-hot columns, plus a weighted KL term. The KL weight is
    ``kl_beta * min(1, epoch / max(1, kl_warmup_epochs))``.

    Returns:
        ``(model, device, one_hot_y)``: the trained model (still in train mode), the
        torch device it lives on, and the label -> one-hot helper used in training.
    """
    y_dim = 2  # labels are 0 (majority) / 1 (minority)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    features = torch.tensor(X_train, dtype=torch.float32)
    labels = torch.tensor(y_train.values, dtype=torch.long)
    loader = DataLoader(
        TensorDataset(features, labels), batch_size=cfg.batch_size, shuffle=True, drop_last=False
    )

    model = CVAE(
        x_dim=X_train.shape[1],
        y_dim=y_dim,
        num_dim=num_dim,
        cat_dim=cat_dim,
        latent_dim=latent_dim,
        hidden_dims=hidden_dims,
        dropout=dropout,
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
    numeric_loss = nn.MSELoss(reduction="mean")
    categorical_loss = nn.BCEWithLogitsLoss(reduction="mean")
    zero = torch.tensor(0.0, device=device)  # stands in for an absent reconstruction term

    def one_hot_y(y_long):
        return F.one_hot(y_long, num_classes=y_dim).float()

    warmup = max(1, cfg.kl_warmup_epochs)
    log_every = max(1, cfg.epochs // 5)

    logger.info("Starting CVAE training")
    model.train()
    for epoch in range(1, cfg.epochs + 1):
        kl_weight = cfg.kl_beta * min(1.0, epoch / warmup)
        running_loss, seen = 0.0, 0
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            num_pred, cat_logits, mu, logvar = model(xb, one_hot_y(yb))

            loss_num = numeric_loss(num_pred, xb[:, :num_dim]) if num_dim > 0 else zero
            loss_cat = categorical_loss(cat_logits, xb[:, num_dim:]) if cat_dim > 0 else zero
            kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
            loss = loss_num + loss_cat + kl_weight * kl

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            rows = xb.size(0)
            running_loss += loss.item() * rows
            seen += rows

        if epoch == 1 or epoch % log_every == 0 or epoch == cfg.epochs:
            logger.info(
                "Epoch %d/%d - loss=%.5f (kl_weight=%.3f)",
                epoch,
                cfg.epochs,
                running_loss / max(seen, 1),
                kl_weight,
            )

    logger.info("Training complete")
    return model, device, one_hot_y


def _columns_kept_by_imputer(pipe, cols):
    """Return the subset of *cols* whose values survive the pipeline's ``imputer`` step.

    With its default settings, ``SimpleImputer`` discards columns that were entirely
    missing at fit time and records NaN as their statistic. Later steps (scaler,
    encoder) therefore see fewer columns than *cols*, and decoded columns must be
    mapped to the surviving names, not to *cols* positionally.
    """
    imputer = pipe.named_steps.get("imputer")
    stats = getattr(imputer, "statistics_", None)
    if stats is None or len(stats) != len(cols):
        return list(cols)
    return [col for col, stat in zip(cols, stats) if not (isinstance(stat, float) and np.isnan(stat))]


def _inverse_transform_synth(preprocessor, X_synth, feature_cols, numeric_cols, categorical_cols, num_dim):
    """Map a synthetic matrix in preprocessed space back to a feature DataFrame.

    The first *num_dim* columns of *X_synth* are un-standardized with the fitted
    scaler. The remaining one-hot columns are decoded with the fitted encoder. Each
    decoded column is assigned to the feature name that actually produced it (see
    ``_columns_kept_by_imputer``). Features with no decoded values come back as NaN.

    The manual inversion is always used: ``ColumnTransformer`` offers no
    ``inverse_transform``, and its output order is numeric-then-categorical rather
    than *feature_cols* order.

    Returns:
        DataFrame with exactly ``X_synth.shape[0]`` rows and columns *feature_cols*.
    """
    n_rows = X_synth.shape[0]
    num_synth = X_synth[:, :num_dim] if num_dim > 0 else np.zeros((n_rows, 0))
    cat_synth = X_synth[:, num_dim:] if X_synth.shape[1] > num_dim else np.zeros((n_rows, 0))

    data = {}
    if numeric_cols:
        num_pipe = preprocessor.named_transformers_["num"]
        scaler = num_pipe.named_steps.get("scaler")
        if scaler is not None:
            num_synth = scaler.inverse_transform(num_synth)
        kept = _columns_kept_by_imputer(num_pipe, numeric_cols)
        for idx, col in enumerate(kept[: num_synth.shape[1]]):
            data[col] = num_synth[:, idx]

    if categorical_cols:
        cat_pipe = preprocessor.named_transformers_["cat"]
        cat_inv = cat_pipe.named_steps["onehot"].inverse_transform(cat_synth)
        kept = _columns_kept_by_imputer(cat_pipe, categorical_cols)
        for idx, col in enumerate(kept[: cat_inv.shape[1]]):
            data[col] = cat_inv[:, idx]

    # An explicit index keeps the row count even if no column was decoded;
    # reindex adds every undecoded feature as an all-NaN column.
    return pd.DataFrame(data, index=pd.RangeIndex(n_rows)).reindex(columns=feature_cols)


def _build_synthetic_rows(
    model,
    device,
    one_hot_y,
    n_needed,
    latent_dim,
    preprocessor,
    feature_cols,
    numeric_cols,
    categorical_cols,
    num_dim,
    cat_group_sizes,
):
    """Generate *n_needed* minority-class feature rows from the trained CVAE.

    Latent codes are drawn from N(0, I) and decoded conditioned on label 1. Each
    categorical group's logits are hardened to a one-hot via argmax. The resulting
    matrix is then mapped back through ``_inverse_transform_synth``.

    Returns:
        The DataFrame produced by ``_inverse_transform_synth`` (columns *feature_cols*).
    """
    model.eval()
    with torch.no_grad():
        logger.info("Starting synthesis")
        z = torch.randn(n_needed, latent_dim, device=device)
        minority_labels = torch.ones(n_needed, dtype=torch.long, device=device)
        num_pred, cat_logits = model.decode(z, one_hot_y(minority_labels).to(device))

        if num_dim > 0:
            numeric_part = num_pred.cpu().numpy()
        else:
            numeric_part = np.zeros((n_needed, 0), dtype=np.float32)

        if cat_logits is None or not cat_group_sizes:
            X_synth = numeric_part
        else:
            logits = cat_logits.cpu().numpy()
            hardened = np.zeros_like(logits, dtype=np.float32)
            rows = np.arange(n_needed)
            offsets = np.cumsum([0, *cat_group_sizes])
            for start, end in zip(offsets[:-1], offsets[1:]):
                hardened[rows, start + logits[:, start:end].argmax(axis=1)] = 1.0
            X_synth = np.concatenate([numeric_part, hardened], axis=1)

    return _inverse_transform_synth(
        preprocessor=preprocessor,
        X_synth=X_synth,
        feature_cols=feature_cols,
        numeric_cols=numeric_cols,
        categorical_cols=categorical_cols,
        num_dim=num_dim,
    )


def _finalize_synthetic_features(synth_features_df, working_df, numeric_cols, categorical_cols, numeric_stats):
    """Coerce decoded synthetic features back to the training data's conventions.

    Modifies *synth_features_df* in place and returns it:

    * categorical columns are cast to *working_df*'s dtype (object if that fails);
    * numeric columns are coerced to numbers. A column with a single observed value
      is pinned to that constant. Otherwise it is clipped to the observed
      [min, max] and rounded when every observed value was integral;
    * integral numeric columns whose *working_df* dtype is an integer dtype are cast
      back to it when they contain no missing value. Otherwise, concatenating with
      the original rows would silently turn integer columns into floats.
    """
    is_integer_dtype = pd.api.types.is_integer_dtype

    for col in categorical_cols:
        try:
            synth_features_df[col] = synth_features_df[col].astype(working_df[col].dtype)
        except Exception:
            # Deliberately best-effort: keep the values as object rather than fail the recipe.
            synth_features_df[col] = synth_features_df[col].astype("object")

    for col in numeric_cols:
        synth_features_df[col] = pd.to_numeric(synth_features_df[col], errors="coerce")

    for col in numeric_cols:
        stats = numeric_stats.get(col)
        if stats is None:
            continue
        if stats["single_value"] is not None:
            synth_features_df[col] = stats["single_value"]
        else:
            synth_features_df[col] = synth_features_df[col].clip(lower=stats["min"], upper=stats["max"])
            if stats["all_int"]:
                synth_features_df[col] = synth_features_df[col].round()

        original_dtype = working_df[col].dtype
        if stats["all_int"] and is_integer_dtype(original_dtype) and not synth_features_df[col].isna().any():
            # Values are integral and within the original range, so the cast is exact.
            synth_features_df[col] = synth_features_df[col].astype(original_dtype)

    return synth_features_df


def _count_synthetic_rows_needed(base_output_df, target_column, minority_value, majority_value, max_multiplier):
    """Return how many minority rows to synthesize so the minority matches the majority.

    The gap is capped at ``int(max_multiplier * len(base_output_df))``. The result is
    never negative, even for a zero or negative multiplier.
    """
    counts = base_output_df[target_column].value_counts()
    gap = int(counts.get(majority_value, 0)) - int(counts.get(minority_value, 0))
    cap = int(max_multiplier * len(base_output_df))
    return max(0, min(gap, cap))


def _restore_datetime_columns(synth_df, reference_df, columns):
    """Parse stringified datetime features in *synth_df* back to *reference_df*'s datetime dtype.

    Datetime features are one-hot encoded as strings, so decoded synthetic rows carry
    strings. Without this step, the concatenated output column would mix Timestamps
    (original rows) with strings (synthetic rows). Unparseable values (for example
    ``"__OTHER__"``) become NaT. Modifies *synth_df* in place and returns it.
    """
    for col in columns:
        dtype = reference_df[col].dtype
        if not is_datetime64_any_dtype(dtype):
            continue
        tz = getattr(dtype, "tz", None)
        parsed = pd.to_datetime(synth_df[col], errors="coerce", utc=tz is not None)
        if tz is not None:
            parsed = parsed.dt.tz_convert(tz)
        synth_df[col] = parsed
    return synth_df


def _fit_and_synthesize(cfg, working_df, target_encoded, feature_cols, numeric_cols, categorical_cols, numeric_stats, n_needed):
    """Fit the preprocessor and the CVAE, then return *n_needed* finalized synthetic feature rows."""
    df_train = working_df
    if cfg.max_rows > 0 and len(df_train) > cfg.max_rows:
        df_train = df_train.sample(n=cfg.max_rows, random_state=cfg.random_seed)

    preprocessor = build_preprocessor(numeric_cols, categorical_cols)
    X_train = preprocessor.fit_transform(df_train[feature_cols])
    y_train = target_encoded.loc[df_train.index]
    logger.info("Dataset preprocessed")

    num_dim, cat_dim, cat_group_sizes = _extract_dims(
        preprocessor=preprocessor,
        numeric_cols=numeric_cols,
        categorical_cols=categorical_cols,
        x_dim=X_train.shape[1],
    )

    model, device, one_hot_y = _train_model(
        X_train=X_train,
        y_train=y_train,
        cfg=cfg,
        num_dim=num_dim,
        cat_dim=cat_dim,
        latent_dim=cfg.latent_dim,
        hidden_dims=cfg.hidden_dims,
        dropout=cfg.dropout,
    )

    synth_features_df = _build_synthetic_rows(
        model=model,
        device=device,
        one_hot_y=one_hot_y,
        n_needed=n_needed,
        latent_dim=cfg.latent_dim,
        preprocessor=preprocessor,
        feature_cols=feature_cols,
        numeric_cols=numeric_cols,
        categorical_cols=categorical_cols,
        num_dim=num_dim,
        cat_group_sizes=cat_group_sizes,
    )

    return _finalize_synthetic_features(
        synth_features_df=synth_features_df,
        working_df=working_df,
        numeric_cols=numeric_cols,
        categorical_cols=categorical_cols,
        numeric_stats=numeric_stats,
    )


def run_cvae_oversample(input_dataset_name: str, output_dataset_name: str, config: dict):
    """Rebalance a binary target by appending CVAE-generated minority rows.

    Steps:
    1. Read *input_dataset_name* and drop rows with a null target.
    2. Compute how many minority rows are needed to match the majority, capped
       by ``max_synthetic_multiplier``.
    3. Train a CVAE and generate that many rows.
    4. Write original + synthetic rows to *output_dataset_name*, with a boolean
       ``synthesized`` column marking the generated rows.

    When no rows are needed, training is skipped and the (null-target-filtered)
    input is written with ``synthesized=False``.

    Raises:
        ValueError: on a missing or non-binary target column, no feature columns, or
            invalid settings reported by the config/feature helpers.
    """
    cfg = CVAEOversampleConfig.from_recipe_config(config)
    set_seed(cfg.random_seed)

    original_df = dataiku.Dataset(input_dataset_name).get_dataframe()
    logger.info("Dataset read")

    if cfg.target_column not in original_df.columns:
        raise ValueError("Target column '%s' not found in dataset" % cfg.target_column)

    working_df, target_encoded, minority_value, majority_value = _prepare_target(
        original_df.copy(), cfg.target_column
    )
    # Snapshot before split_columns() rewrites datetime columns in place; this is
    # what gets written out, with the original dtypes.
    base_output_df = working_df.copy()

    feature_cols, numeric_cols, categorical_cols = split_columns(working_df, cfg.target_column)
    if cfg.max_categories > 0 and categorical_cols:
        logger.info("Limiting categories to top %d per categorical column", cfg.max_categories)
        working_df = limit_categories(working_df, categorical_cols, cfg.max_categories)

    working_df, feature_cols, numeric_cols, categorical_cols = drop_empty_features(
        working_df, feature_cols, numeric_cols, categorical_cols
    )

    n_needed = _count_synthetic_rows_needed(
        base_output_df, cfg.target_column, minority_value, majority_value, cfg.max_synthetic_multiplier
    )

    output_df = base_output_df.copy()
    output_df["synthesized"] = False
    if n_needed > 0:
        numeric_stats = _build_numeric_stats(base_output_df, numeric_cols)
        synth_features_df = _fit_and_synthesize(
            cfg=cfg,
            working_df=working_df,
            target_encoded=target_encoded,
            feature_cols=feature_cols,
            numeric_cols=numeric_cols,
            categorical_cols=categorical_cols,
            numeric_stats=numeric_stats,
            n_needed=n_needed,
        )
        synth_features_df = _restore_datetime_columns(synth_features_df, base_output_df, feature_cols)
        synth_features_df[cfg.target_column] = minority_value
        # Flag the synthetic part itself rather than "the last n rows" of the result.
        synth_features_df["synthesized"] = True
        output_df = pd.concat([output_df, synth_features_df], ignore_index=True)
    else:
        logger.info("No synthetic rows needed; skipping CVAE training")

    dataiku.Dataset(output_dataset_name).write_with_schema(output_df)
