import logging
import pandas as pd
import numpy as np
from snsynth import Synthesizer

# Module logger plus a default logging configuration.
# NOTE(review): logging.basicConfig() configures the *root* logger at import
# time (and does nothing if the root logger already has handlers), so it
# affects any application importing this module; consider moving it to the
# application's entry point.
logger = logging.getLogger(__name__)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class SmartNoiseWrapper:
    """
    A unified wrapper for SmartNoise synthesizers (MWEM, MST, AIM, DP-CTGAN, PATE-CTGAN).

    Attributes:
        model_name (str): Lower-cased name of the synthesizer; one of SUPPORTED_MODELS
            ('mwem', 'mst', 'dpctgan', 'patectgan', 'aim').
        epsilon (float): Privacy budget passed to ``Synthesizer.create``.
        model_kwargs (dict): Additional arguments for the specific synthesizer.
        synthesizer (object): The fitted SmartNoise synthesizer, or None until fit()
            succeeds (or after loading a wrapper that was saved unfitted).
    """

    SUPPORTED_MODELS = ['mwem', 'mst', 'dpctgan', 'patectgan', 'aim']

    # Key/version stamped into the dict written by save(). Files written by earlier
    # versions of this class contain only the bare synthesizer; load() accepts both.
    _SAVE_FORMAT_KEY = '__smartnoise_wrapper_format__'
    _SAVE_FORMAT_VERSION = 1

    def __init__(self, model_name='dpctgan', epsilon=1.0, **model_kwargs):
        """
        Initialize the synthesizer wrapper. No synthesizer is created until fit().

        Args:
            model_name (str): The algorithm to use (case-insensitive). Defaults to 'dpctgan'.
            epsilon (float): Privacy budget to be allocated to the synthesizer.
            **model_kwargs: Arbitrary keyword arguments passed to the underlying synthesizer
                            (e.g., iterations=10 for MWEM, batch_size=500 for GANs).

        Raises:
            ValueError: If ``model_name`` is not one of SUPPORTED_MODELS.
        """
        normalized_name = model_name.lower()
        if normalized_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Model '{model_name}' not supported. Choose from {self.SUPPORTED_MODELS}")

        self.model_name = normalized_name
        self.epsilon = epsilon
        self.model_kwargs = model_kwargs
        self.synthesizer = None
        self._is_fitted = False

    def fit(self, df, categorical_columns=None, continuous_columns=None, preprocessor_eps=0.0):
        """
        Create a fresh synthesizer and fit it to the private data.

        The new synthesizer is stored on the wrapper only after fitting succeeds.
        A failed (re)fit therefore leaves any previously fitted model, and the
        fitted flag, untouched.

        Args:
            df (pd.DataFrame): The sensitive dataset.
            categorical_columns (list, optional): Column names to treat as categorical.
                If None, every column with an object, category, string or bool dtype
                is used.
            continuous_columns (list, optional): Column names to treat as continuous.
                Passed through unchanged (including None) to the synthesizer.
            preprocessor_eps (float): Budget to spend on data preprocessing (e.g. scaling).
                                      Subtracted from the total epsilon.

        Raises:
            Exception: Whatever ``Synthesizer.create`` or the synthesizer's ``fit``
                raises; it is logged with its traceback and re-raised unchanged.
        """
        logger.info("Initializing %s with epsilon=%s...", self.model_name.upper(), self.epsilon)

        # MWEM and MST often benefit from explicit categorical column definitions.
        if categorical_columns is None:
            # Heuristic: non-numeric dtypes are categorical. bool and pandas' dedicated
            # 'string' dtype are included so those columns are not silently left out.
            categorical_columns = df.select_dtypes(
                include=['object', 'category', 'string', 'bool']
            ).columns.tolist()
            logger.info("Automatically inferred categorical columns: %s", categorical_columns)

        # Build into a local so self.synthesizer is never an unfitted object while
        # self._is_fitted is still True from an earlier successful fit.
        try:
            synthesizer = Synthesizer.create(
                self.model_name,
                epsilon=self.epsilon,
                **self.model_kwargs
            )
        except Exception as e:
            logger.exception("Failed to create synthesizer: %s", e)
            raise

        logger.info("Fitting model on data with shape %s...", df.shape)

        try:
            synthesizer.fit(
                df,
                categorical_columns=categorical_columns,
                continuous_columns=continuous_columns,
                preprocessor_eps=preprocessor_eps
            )
        except Exception as e:
            logger.exception("Error during fitting: %s", e)
            raise

        # Commit only after a successful fit.
        self.synthesizer = synthesizer
        self._is_fitted = True
        logger.info("Model fitting complete.")

    def sample(self, n_samples):
        """
        Generate synthetic data.

        Args:
            n_samples (int): Number of rows to generate.

        Returns:
            pd.DataFrame: A dataframe of synthetic data.

        Raises:
            RuntimeError: If no fitted synthesizer is available.
            Exception: Whatever the synthesizer's ``sample`` or the DataFrame
                conversion raises; it is logged and re-raised unchanged.
        """
        if not self._is_fitted or self.synthesizer is None:
            raise RuntimeError("The model is not fitted. Please run .fit() first.")

        logger.info("Sampling %s rows...", n_samples)

        try:
            synthetic_data = self.synthesizer.sample(n_samples)

            # Some SmartNoise models may return lists of lists rather than DataFrames.
            # Normalise the output, reusing the synthesizer's column names when it
            # exposes a `columns` attribute.
            if not isinstance(synthetic_data, pd.DataFrame):
                cols = getattr(self.synthesizer, 'columns', None)
                synthetic_data = pd.DataFrame(synthetic_data, columns=cols)
        except Exception as e:
            logger.exception("Error during sampling: %s", e)
            raise

        return synthetic_data

    def save(self, path):
        """
        Pickle the synthesizer together with the wrapper's configuration.

        The file holds a dict with the synthesizer, model_name, epsilon,
        model_kwargs and the fitted flag, so load() can restore the wrapper
        faithfully. An unfitted wrapper can still be saved; a warning is logged.

        Args:
            path (str | os.PathLike): Destination file; overwritten if it exists.
        """
        import pickle
        if not self._is_fitted:
            logger.warning("Saving an unfitted model.")

        payload = {
            self._SAVE_FORMAT_KEY: self._SAVE_FORMAT_VERSION,
            'synthesizer': self.synthesizer,
            'model_name': self.model_name,
            'epsilon': self.epsilon,
            'model_kwargs': dict(self.model_kwargs),
            'is_fitted': self._is_fitted,
        }
        with open(path, 'wb') as f:
            pickle.dump(payload, f)
        logger.info("Model saved to %s", path)

    @classmethod
    def load(cls, path):
        """
        Load a wrapper previously written by save().

        Accepts two formats:
        - the current format, a dict carrying the synthesizer plus model_name,
          epsilon and model_kwargs;
        - the legacy format, the bare pickled synthesizer. Legacy files carry no
          metadata, so the wrapper gets the constructor defaults for those fields.

        The wrapper is marked fitted only if a synthesizer is actually present.

        SECURITY: pickle.load can execute arbitrary code while unpickling. Only
        load files from trusted sources.

        Args:
            path (str | os.PathLike): File written by save().

        Returns:
            SmartNoiseWrapper: The reconstructed wrapper (an instance of ``cls``).

        Raises:
            ValueError: If the file was written by a newer, unknown save format.
        """
        import pickle
        with open(path, 'rb') as f:
            payload = pickle.load(f)

        if isinstance(payload, dict) and cls._SAVE_FORMAT_KEY in payload:
            version = payload[cls._SAVE_FORMAT_KEY]
            if version > cls._SAVE_FORMAT_VERSION:
                raise ValueError(f"Unsupported save format version {version} in {path}")
            wrapper = cls(
                model_name=payload['model_name'],
                epsilon=payload['epsilon'],
                **payload.get('model_kwargs', {})
            )
            synthesizer = payload['synthesizer']
            is_fitted = bool(payload.get('is_fitted', True))
        else:
            # Legacy format: the file contains only the synthesizer object.
            wrapper = cls()
            synthesizer = payload
            is_fitted = True

        wrapper.synthesizer = synthesizer
        # A missing synthesizer can never be sampled from, whatever the flag says.
        wrapper._is_fitted = is_fitted and synthesizer is not None
        return wrapper