import logging

import dataiku
from dataiku.customrecipe import get_input_names_for_role, get_output_names_for_role, get_recipe_config

from smartnoise import SmartNoiseWrapper


def _load_dataset_from_role(role_name):
	"""Return a ``dataiku.Dataset`` handle for the first input bound to a role.

	Args:
		role_name: Name of the recipe input role to look up.

	Returns:
		A ``dataiku.Dataset`` built from the first dataset name reported for
		the role; any further names are ignored.

	Raises:
		ValueError: If no dataset name is reported for ``role_name``.
	"""
	names_for_role = get_input_names_for_role(role_name)
	if names_for_role:
		first_name = names_for_role[0]
		return dataiku.Dataset(first_name)
	raise ValueError(f"No dataset provided for role '{role_name}'")


logger = logging.getLogger(__name__)


# Fall back to an empty mapping so every lookup below can rely on ``.get``.
config = get_recipe_config() or {}


def _config_number(key, default, cast):
	"""Read a numeric recipe setting, treating a missing or null value as *default*.

	``config.get(key, default)`` only falls back when the key is absent; a key
	that is present with a ``None`` value would otherwise reach ``cast`` and
	raise an opaque ``TypeError``.

	Args:
		key: Name of the setting in the recipe configuration.
		default: Value used when the setting is absent or ``None``.
		cast: Callable (``int`` or ``float``) applied to the selected value.

	Returns:
		The result of ``cast`` on the configured value or on ``default``.
	"""
	raw_value = config.get(key)
	return cast(default if raw_value is None else raw_value)


model_name = config.get('model_name', 'dpctgan')
epsilon = _config_number('epsilon', 1.0, float)
sample_rows = _config_number('sample_rows', 1000, int)
auto_detect_schema = config.get('auto_detect_schema', True)
preprocessor_eps = _config_number('preprocessor_epsilon', 0.5, float)
# Categorical columns are never read from the configuration; they are derived
# from the input dataframe later in the script.
categorical_columns = None
# An empty selection is normalised to None (continuous) or [] (excluded).
continuous_columns = config.get('continuous_columns') or None
excluded_columns = config.get('excluded_columns') or []
show_advanced = bool(config.get('show_advanced', False))
# Placeholder; the model-specific arguments are built further down.
model_kwargs = {}

if sample_rows <= 0:
	raise ValueError("'Rows to generate' must be a positive integer")

supported_models = {'dpctgan', 'patectgan', 'mwem'}
if model_name not in supported_models:
	raise ValueError(f"Model '{model_name}' is not supported. Choose from {sorted(supported_models)}")

# Validate the privacy budgets up front, like every other numeric setting.
# ``not x > 0`` (rather than ``x <= 0``) also rejects NaN.
if not epsilon > 0:
	raise ValueError(f"Epsilon must be a positive number (got {epsilon})")
if not preprocessor_eps >= 0:
	raise ValueError(f"Preprocessor epsilon must be zero or positive (got {preprocessor_eps})")

def _build_model_kwargs(name, cfg, use_advanced):
	if name == 'mwem':
		if not use_advanced:
			return {}
		iterations = int(cfg.get('mwem_iterations', 50))
		split_factor = int(cfg.get('mwem_split_factor', 3))
		marginal_width = int(cfg.get('mwem_marginal_width', 2))
		if iterations <= 0 or split_factor <= 0 or marginal_width <= 0:
			raise ValueError('MWEM parameters must be positive integers')
		return {
			'iterations': iterations,
			'split_factor': split_factor,
			'marginal_width': marginal_width,
		}
	if name == 'dpctgan':
		if not use_advanced:
			return {}
		sigma = float(cfg.get('dpctgan_sigma', 1.0))
		batch_size = int(cfg.get('dpctgan_batch_size', 512))
		epochs = int(cfg.get('dpctgan_epochs', 100))
		if sigma <= 0 or batch_size <= 0 or epochs <= 0:
			raise ValueError('DP-CTGAN parameters must be positive')
		return {
			'sigma': sigma,
			'batch_size': batch_size,
			'epochs': epochs,
		}
	if name == 'patectgan':
		if not use_advanced:
			return {}
		teacher_iters = int(cfg.get('pate_teacher_iters', 5))
		student_iters = int(cfg.get('pate_student_iters', 1))
		sample_per_teacher = int(cfg.get('pate_sample_per_teacher', 1000))
		category_epsilon_pct = float(cfg.get('pate_category_epsilon_pct', 0.5))
		if teacher_iters <= 0 or student_iters <= 0 or sample_per_teacher <= 0:
			raise ValueError('PATE-CTGAN parameters must be positive')
		if not 0 <= category_epsilon_pct <= 1:
			raise ValueError('PATE categorical epsilon % must be between 0 and 1')
		return {
			'teacher_iters': teacher_iters,
			'student_iters': student_iters,
			'sample_per_teacher': sample_per_teacher,
			'category_epsilon_pct': category_epsilon_pct,
		}
	return {}

model_kwargs = _build_model_kwargs(model_name, config, show_advanced)

sensitive_dataset = _load_dataset_from_role('sensitive_dataset')

# Check the output role before reading data or training, so a mis-wired recipe
# fails immediately instead of after the fit and sampling steps have run.
output_names = get_output_names_for_role('synthetic_output')
if not output_names:
	raise ValueError("No dataset provided for output role 'synthetic_output'")
synthetic_dataset = dataiku.Dataset(output_names[0])

private_df = sensitive_dataset.get_dataframe()

if excluded_columns:
	# errors='ignore': excluded names that are not in the dataframe are skipped.
	private_df = private_df.drop(columns=excluded_columns, errors='ignore')

if private_df.shape[1] == 0:
	raise ValueError("The sensitive dataset has no columns left after applying the excluded columns")

if auto_detect_schema:
	# Numeric, non-boolean columns are continuous; everything else is categorical.
	numeric_columns = private_df.select_dtypes(include=['number']).columns.tolist()
	bool_columns = set(private_df.select_dtypes(include=['bool']).columns)
	continuous_columns = [col for col in numeric_columns if col not in bool_columns]
	continuous_lookup = set(continuous_columns)
	categorical_columns = [col for col in private_df.columns if col not in continuous_lookup]
	logger.info("Auto-detected continuous columns (%d): %s", len(continuous_columns), continuous_columns)
	logger.info("Auto-detected categorical columns (%d): %s", len(categorical_columns), categorical_columns)
	# Empty lists are passed on as None.
	continuous_columns = continuous_columns or None
	categorical_columns = categorical_columns or None
else:
	requested_continuous = continuous_columns or []
	available_columns = set(private_df.columns)
	missing_continuous = [col for col in requested_continuous if col not in available_columns]
	if missing_continuous:
		logger.warning(
			"Ignoring continuous columns not present in the input after exclusions: %s",
			missing_continuous,
		)
	kept_continuous = [col for col in requested_continuous if col in available_columns]
	kept_lookup = set(kept_continuous)
	# Every column not explicitly declared continuous is treated as categorical.
	categorical_candidates = [col for col in private_df.columns if col not in kept_lookup]
	continuous_columns = kept_continuous or None
	categorical_columns = categorical_candidates or None
	logger.info(
		"Manual schema: %d continuous / %d categorical (categorical inferred as remaining columns)",
		len(continuous_columns or []),
		len(categorical_candidates),
	)

wrapper = SmartNoiseWrapper(
	model_name=model_name,
	epsilon=epsilon,
	**model_kwargs,
)

wrapper.fit(
	private_df,
	categorical_columns=categorical_columns,
	continuous_columns=continuous_columns,
	preprocessor_eps=preprocessor_eps,
)

synthetic_df = wrapper.sample(sample_rows)

synthetic_dataset.write_with_schema(synthetic_df)