import pandas as pd
import numpy as np
import logging
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
from sklearn.metrics import silhouette_score, davies_bouldin_score
from sklearn.ensemble import IsolationForest
from scipy import stats
from scipy.stats import pearsonr, spearmanr, f_oneway, zscore
from typing import Dict, List, Optional, Union, Tuple


#============== Common Utilities ==============

def identify_column_types(X: pd.DataFrame) -> Tuple[List[str], List[str], Dict[str, str]]:
    """
    Split the columns of ``X`` into numeric and categorical groups.

    Any column whose dtype is not numeric is categorical. A numeric-dtype
    column is treated as categorical only when it looks like a label
    encoding: it has at most ``min(20, 10% of rows)`` distinct values, more
    than one of them, all integers, and they are exactly ``0, 1, ..., k-1``.

    Returns
    -------
    Tuple[List[str], List[str], Dict[str, str]]
        (numeric_cols, categorical_cols, column_types) where column_types
        maps each column name to "numeric" or "categorical".
    """
    numeric_cols: List[str] = []
    categorical_cols: List[str] = []
    column_types: Dict[str, str] = {}

    for name in X.columns:
        series = X[name]
        kind = "categorical"

        if pd.api.types.is_numeric_dtype(series):
            kind = "numeric"
            distinct_count = series.nunique()
            cardinality_cap = min(20, len(series) * 0.1)

            # Low-cardinality numeric columns may really be encoded categories.
            if distinct_count <= cardinality_cap and distinct_count > 0:
                distinct = sorted(series.dropna().unique())
                all_integers = len(distinct) > 1 and all(
                    isinstance(value, (int, np.integer)) for value in distinct
                )
                # Consecutive integers starting at 0 are the signature of label encoding.
                if all_integers and distinct == list(range(len(distinct))):
                    kind = "categorical"

        if kind == "numeric":
            numeric_cols.append(name)
        else:
            categorical_cols.append(name)
        column_types[name] = kind

    return numeric_cols, categorical_cols, column_types


#============== Clustering Utilities ==============

def prepare_mixed_data(
    X: pd.DataFrame,
    numeric_cols: List[str],
    categorical_cols: List[str],
    max_categories_for_onehot: int,
    standardize: bool
) -> Tuple[np.ndarray, Dict]:
    """
    Prepares mixed data (numeric + categorical) for clustering.

    Numeric columns are mean-imputed and optionally standardized; they always
    occupy the first ``len(numeric_cols)`` columns of the output. Each
    categorical column is mode-imputed and then either one-hot encoded (first
    level dropped) when it has at most ``max_categories_for_onehot`` distinct
    values, or label-encoded into a single column otherwise.

    Parameters
    ----------
    X : pd.DataFrame
        Source data containing every column named below.
    numeric_cols : List[str]
        Columns to treat as numeric.
    categorical_cols : List[str]
        Columns to treat as categorical.
    max_categories_for_onehot : int
        Cardinality limit above which label encoding replaces one-hot encoding.
    standardize : bool
        Whether to fit a ``StandardScaler`` on the numeric block.

    Returns
    -------
    Tuple[np.ndarray, Dict]
        (X_encoded, encoders_info). ``encoders_info`` holds the column lists,
        the fitted ``"scaler"`` (only when ``standardize`` is true and numeric
        columns exist) and, per categorical column, the fitted encoder, its
        type ("onehot" or "label") and the generated feature names.

    Raises
    ------
    ValueError
        If both ``numeric_cols`` and ``categorical_cols`` are empty.
    """
    encoded_parts = []
    encoders_info = {
        "numeric_cols": numeric_cols,
        "categorical_cols": categorical_cols,
        "encoders": {}
    }

    # Process numeric columns
    if numeric_cols:
        X_numeric = X[numeric_cols].copy()

        # Mean imputation. A column with no observed value at all has a NaN
        # mean, so a second pass fills it with 0.0; otherwise NaN would reach
        # the clustering algorithm.
        X_numeric = X_numeric.fillna(X_numeric.mean()).fillna(0.0)

        if standardize:
            scaler = StandardScaler()
            X_numeric_scaled = scaler.fit_transform(X_numeric)
            encoders_info["scaler"] = scaler
        else:
            X_numeric_scaled = X_numeric.values

        encoded_parts.append(X_numeric_scaled)

    # Process categorical columns
    for col in categorical_cols:
        X_cat = X[[col]].copy()

        # Mode imputation; "Unknown" is used only when the column has no mode
        # (i.e. it contains no observed value).
        mode_val = X_cat[col].mode()
        fill_value = mode_val[0] if len(mode_val) > 0 else "Unknown"
        X_cat[col] = X_cat[col].fillna(fill_value)

        n_unique = X_cat[col].nunique()

        if n_unique <= max_categories_for_onehot:
            # One-hot encoding if few categories
            ohe = OneHotEncoder(sparse_output=False, drop='first', handle_unknown='ignore')
            X_cat_encoded = ohe.fit_transform(X_cat)
            encoders_info["encoders"][col] = {
                "type": "onehot",
                "encoder": ohe,
                "feature_names": ohe.get_feature_names_out([col]).tolist()
            }
        else:
            # Label encoding for many categories
            le = LabelEncoder()
            X_cat_encoded = le.fit_transform(X_cat[col].astype(str)).reshape(-1, 1)
            encoders_info["encoders"][col] = {
                "type": "label",
                "encoder": le,
                "feature_names": [col]
            }

        encoded_parts.append(X_cat_encoded)

    if not encoded_parts:
        raise ValueError("No valid columns found for clustering")

    # Concatenate all features (numeric block first, then categoricals in order)
    X_encoded = np.hstack(encoded_parts)

    return X_encoded, encoders_info


def extract_numeric_centers(
    cluster_centers: np.ndarray,
    encoders_info: Dict,
    numeric_cols: List[str],
    standardize: bool
) -> np.ndarray:
    """
    Return the numeric part of the cluster centers in original units.

    The numeric features are taken to be the leading ``len(numeric_cols)``
    columns of ``cluster_centers``. When ``standardize`` is true and
    ``encoders_info`` carries a ``"scaler"``, its ``inverse_transform`` is
    applied to undo the scaling. With no numeric columns an empty
    ``(n_centers, 0)`` array is returned.
    """
    if not numeric_cols:
        n_centers = len(cluster_centers)
        return np.array([]).reshape(n_centers, 0)

    # Slice off the numeric block (it precedes the encoded categoricals).
    width = len(numeric_cols)
    centers = cluster_centers[:, :width]

    needs_unscaling = standardize and "scaler" in encoders_info
    if not needs_unscaling:
        return centers

    return encoders_info["scaler"].inverse_transform(centers)


def compute_categorical_distributions(
    df: pd.DataFrame,
    categorical_cols: List[str],
    cluster_labels: np.ndarray,
    n_clusters: int
) -> Dict:
    """
    Describe how each categorical column is distributed inside every cluster.

    The result maps ``column -> "Cluster_<id>" -> {"distribution", "most_common"}``
    where ``distribution`` holds the relative frequency of each category in
    the cluster and ``most_common`` holds up to three modal values (empty for
    a cluster with no rows).
    """
    # Row selectors are the same for every column, so build them once.
    cluster_masks = [cluster_labels == cluster_id for cluster_id in range(n_clusters)]

    distributions = {}
    for col in categorical_cols:
        per_cluster = {}
        for cluster_id, mask in enumerate(cluster_masks):
            members = df[mask][col]
            if len(members) > 0:
                modal_values = members.mode().tolist()[:3]
            else:
                modal_values = []
            per_cluster[f"Cluster_{cluster_id}"] = {
                "distribution": members.value_counts(normalize=True).to_dict(),
                "most_common": modal_values,
            }
        distributions[col] = per_cluster

    return distributions


def compute_clustering_feature_importance(
    df: pd.DataFrame,
    numeric_cols: List[str],
    categorical_cols: List[str],
    cluster_labels: np.ndarray,
    n_clusters: int,
    encoders_info: Dict
) -> Dict[str, float]:
    """
    Computes feature importance (numeric and categorical).

    Numeric columns are scored by the ratio of the variance of the cluster
    means to the mean within-cluster variance. Categorical columns are scored
    by the relative entropy reduction from the global distribution to the
    average per-cluster distribution.

    Clusters for which a statistic is undefined (empty clusters, or
    single-row clusters whose sample variance is NaN) are left out of that
    statistic instead of turning the whole score into NaN/0.

    Parameters
    ----------
    df : pd.DataFrame
        Data with one row per clustered observation.
    numeric_cols, categorical_cols : List[str]
        Columns to score.
    cluster_labels : np.ndarray
        Cluster id of each row of ``df`` (same length, same order).
    n_clusters : int
        Number of clusters; ids ``0 .. n_clusters - 1`` are inspected.
    encoders_info : Dict
        Unused; kept for interface compatibility.

    Returns
    -------
    Dict[str, float]
        Importance per column (rounded to 3 decimals), sorted descending.
    """
    feature_importance = {}
    cluster_masks = [cluster_labels == i for i in range(n_clusters)]

    # Importance of numeric columns
    for col in numeric_cols:
        values = df[col]
        means = np.array([values[mask].mean() for mask in cluster_masks], dtype=float)
        variances = np.array([values[mask].var() for mask in cluster_masks], dtype=float)

        # Drop undefined per-cluster statistics (empty / singleton clusters).
        means = means[~np.isnan(means)]
        variances = variances[~np.isnan(variances)]

        between_cluster_var = np.var(means) if means.size else 0.0
        within_cluster_var = np.mean(variances) if variances.size else 0.0

        if within_cluster_var > 0:
            feature_importance[col] = round(float(between_cluster_var / within_cluster_var), 3)
        else:
            feature_importance[col] = 0.0

    # Importance of categorical columns (based on entropy)
    for col in categorical_cols:
        # Global entropy (the epsilon keeps log2 finite)
        global_dist = df[col].value_counts(normalize=True)
        global_entropy = -np.sum(global_dist * np.log2(global_dist + 1e-10))

        # Average entropy over non-empty clusters
        cluster_entropies = []
        for mask in cluster_masks:
            cluster_data = df[mask][col]
            if len(cluster_data) > 0:
                cluster_dist = cluster_data.value_counts(normalize=True)
                cluster_entropies.append(-np.sum(cluster_dist * np.log2(cluster_dist + 1e-10)))

        avg_cluster_entropy = np.mean(cluster_entropies) if cluster_entropies else 0

        # Importance = entropy reduction (higher value means better cluster separation)
        if global_entropy > 0:
            importance = (global_entropy - avg_cluster_entropy) / global_entropy
            feature_importance[col] = round(float(importance), 3)
        else:
            feature_importance[col] = 0.0

    # Sort by descending importance
    return dict(sorted(feature_importance.items(), key=lambda x: x[1], reverse=True))


def find_optimal_clusters(X: np.ndarray, max_clusters: int, random_state: int) -> int:
    """
    Finds the optimal number of clusters using the elbow method.

    KMeans is fitted for every ``k`` from 2 up to
    ``min(max_clusters, len(X) - 1)`` and the ``k`` at which the inertia curve
    bends most sharply (largest second-order difference) is returned. With
    fewer than three candidate values there is no curvature to measure and
    the smallest candidate (2) is returned.

    Parameters
    ----------
    X : np.ndarray
        Encoded feature matrix, one row per observation.
    max_clusters : int
        Largest number of clusters to try.
    random_state : int
        Seed forwarded to ``KMeans``.

    Returns
    -------
    int
        Selected number of clusters.

    Raises
    ------
    ValueError
        If there is no candidate ``k`` to evaluate, i.e. ``X`` has fewer than
        3 rows or ``max_clusters`` is below 2.
    """
    K_range = range(2, min(max_clusters + 1, len(X)))

    # Without this guard the fallback ``K_range[0]`` fails with an opaque IndexError.
    if len(K_range) == 0:
        raise ValueError(
            "Cannot search for an optimal cluster count: need at least 3 samples "
            f"and max_clusters >= 2 (got {len(X)} samples, max_clusters={max_clusters})"
        )

    inertias = [
        KMeans(n_clusters=k, random_state=random_state, n_init=10).fit(X).inertia_
        for k in K_range
    ]

    if len(inertias) < 3:
        return K_range[0]

    # second_diff[i] measures the bend at K_range[i + 1]
    second_diff = np.diff(inertias, n=2)
    return K_range[int(np.argmax(second_diff)) + 1]


def interpret_quality_metrics(silhouette: float, davies_bouldin: float) -> str:
    """
    Turn the two clustering quality scores into a human-readable verdict.

    The silhouette score is graded against the strict lower bounds 0.5 / 0.3 /
    0.1 and the Davies-Bouldin score against the strict upper bounds 1.0 /
    2.0; the two verdicts are joined with " | ".
    """
    silhouette_bands = (
        (0.5, "Excellent clustering (silhouette > 0.5)"),
        (0.3, "Good clustering (silhouette > 0.3)"),
        (0.1, "Acceptable clustering (silhouette > 0.1)"),
    )
    separation_bands = (
        (1.0, "Well-separated clusters (Davies-Bouldin < 1.0)"),
        (2.0, "Moderately separated clusters (Davies-Bouldin < 2.0)"),
    )

    # First band whose bound is satisfied wins; otherwise use the fallback text.
    silhouette_verdict = next(
        (text for floor, text in silhouette_bands if silhouette > floor),
        "Poor quality clustering (silhouette <= 0.1)",
    )
    separation_verdict = next(
        (text for ceiling, text in separation_bands if davies_bouldin < ceiling),
        "Poorly separated clusters (Davies-Bouldin >= 2.0)",
    )

    return " | ".join([silhouette_verdict, separation_verdict])


def generate_clustering_insights_summary(
    n_clusters: int,
    cluster_sizes: Dict[int, int],
    cluster_stats: Dict,
    categorical_distributions: Dict,
    silhouette: float,
    davies_bouldin: float,
    feature_importance: Dict[str, float],
    column_types: Dict[str, str]
) -> str:
    """
    Build a one-paragraph textual recap of a clustering run.

    The recap states, in order: cluster and observation counts, the number of
    numeric / categorical features, the quality scores, the size of each
    cluster, the first three entries of ``feature_importance`` and up to three
    "dominant category" remarks (a cluster whose first modal category covers
    more than 50% of it, looking only at the first two categorical columns).
    ``cluster_stats`` is accepted but not used.
    """
    total = sum(cluster_sizes.values())
    sentences = [f"Clustering performed with {n_clusters} clusters on {total} observations."]

    # Feature-type breakdown
    kinds = list(column_types.values())
    numeric_count = kinds.count("numeric")
    categorical_count = kinds.count("categorical")
    used = []
    if numeric_count > 0:
        used.append(f"{numeric_count} numeric column(s)")
    if categorical_count > 0:
        used.append(f"{categorical_count} categorical column(s)")
    if used:
        sentences.append(f"Features used: {', '.join(used)}.")

    sentences.append(
        f"Clustering quality: Silhouette score = {silhouette:.3f}, Davies-Bouldin score = {davies_bouldin:.3f}."
    )

    # Size and share of each cluster, in cluster-id order
    shares = []
    for cluster_id in sorted(cluster_sizes):
        members = cluster_sizes[cluster_id]
        share = (members / total) * 100
        shares.append(f"Cluster {cluster_id}: {members} observations ({share:.1f}%)")
    sentences.append(f"Distribution: {', '.join(shares)}.")

    # Leading features (the mapping is expected to be pre-sorted by the caller)
    leaders = [
        f"{name} (importance: {imp:.2f})"
        for name, imp in list(feature_importance.items())[:3]
    ]
    if leaders:
        sentences.append(f"Most discriminative features: {', '.join(leaders)}.")

    # Dominant categories, restricted to two columns and three remarks
    if categorical_distributions:
        remarks = []
        for col, dists in list(categorical_distributions.items())[:2]:
            for cluster_id in range(n_clusters):
                cluster_key = f"Cluster_{cluster_id}"
                if cluster_key not in dists:
                    continue
                entry = dists[cluster_key]
                if not entry["most_common"]:
                    continue
                leader = entry["most_common"][0]
                pct = entry["distribution"].get(leader, 0) * 100
                if pct > 50:
                    remarks.append(
                        f"Cluster {cluster_id} dominated by '{leader}' ({pct:.1f}%) in {col}"
                    )
        if remarks:
            sentences.append(" | ".join(remarks[:3]) + ".")

    return " ".join(sentences)


#============== Outliers Detection Utilities ==============

def detect_numeric_outliers(
    df: pd.DataFrame,
    col: str,
    methods: List[str],
    contamination: float
) -> Optional[Dict]:
    """
    Detects outliers in a numeric column using Isolation Forest.

    Parameters
    ----------
    df : pd.DataFrame
        Data containing ``col``.
    col : str
        Numeric column to analyse; missing values are ignored.
    methods : List[str]
        Kept for compatibility; detection only runs when it contains
        ``"isolation_forest"``.
    contamination : float
        Forwarded to ``IsolationForest(contamination=...)``.

    Returns
    -------
    Optional[Dict]
        ``None`` when the column has fewer than 10 non-null values, when
        ``"isolation_forest"`` is not requested, when the model fails (the
        error is logged) or when no outlier is found. Otherwise a dict with
        the outlier count/percentage, the (up to) 10 most extreme distinct
        outlier values sorted from high to low, the outlier index labels and
        boxplot statistics of the non-null column values.
    """
    col_data = df[col].dropna()

    # Too few observations to fit a meaningful model.
    if len(col_data) < 10:
        return None

    if "isolation_forest" not in methods:
        return None

    try:
        iso_forest = IsolationForest(contamination=contamination, random_state=42)
        predictions = iso_forest.fit_predict(col_data.values.reshape(-1, 1))
    except Exception:
        # Best-effort: one failing column must not abort the whole analysis,
        # but the failure should be visible in the logs.
        logging.exception("Isolation Forest outlier detection failed for column %r", col)
        return None

    # fit_predict marks outliers with -1
    outlier_values = col_data[predictions == -1]
    outlier_indices = set(outlier_values.index)

    if len(outlier_indices) == 0:
        return None

    # Compute boxplot statistics
    q1 = float(col_data.quantile(0.25))
    median = float(col_data.median())
    q3 = float(col_data.quantile(0.75))
    min_val = float(col_data.min())
    max_val = float(col_data.max())
    mean_val = float(col_data.mean())

    # Whiskers at 1.5 * IQR, clipped to the observed range
    iqr = q3 - q1
    lower_whisker = max(min_val, q1 - 1.5 * iqr)
    upper_whisker = min(max_val, q3 + 1.5 * iqr)

    # Keep the 10 distinct outlier values farthest from the median so that
    # both high and low extremes are represented (simply taking the 10 largest
    # values would drop every low-side outlier), then present them high -> low.
    distinct_values = np.unique(outlier_values.to_numpy(dtype=float))
    by_extremeness = np.argsort(-np.abs(distinct_values - median), kind="stable")[:10]
    top_outlier_values = sorted(
        (float(v) for v in distinct_values[by_extremeness]), reverse=True
    )

    return {
        "column": col,
        "outlier_count": len(outlier_indices),
        "outlier_percentage": round(len(outlier_indices) / len(col_data) * 100, 2),
        "method_used": "isolation_forest",
        "top_outlier_values": top_outlier_values,
        "outlier_indices": list(outlier_indices),  # Keep temporarily for pattern calculation
        "boxplot_stats": {
            "min": min_val,
            "q1": q1,
            "median": median,
            "q3": q3,
            "max": max_val,
            "mean": mean_val,
            "lower_whisker": lower_whisker,
            "upper_whisker": upper_whisker
        }
    }


def detect_categorical_outliers(df: pd.DataFrame, col: str) -> Optional[Dict]:
    """
    Detects outliers in a categorical column (rare categories).

    A category is rare when its count is strictly below
    ``max(1, int(n * 0.01))`` where ``n`` is the number of non-null values.
    Note that this threshold is 1 for ``n < 200``, so nothing is flagged in
    such small columns.

    Parameters
    ----------
    df : pd.DataFrame
        Data containing ``col``.
    col : str
        Categorical column to analyse; missing values are ignored.

    Returns
    -------
    Optional[Dict]
        ``None`` when the column has no non-null value or no rare category.
        Otherwise a dict with the number/percentage of rows belonging to rare
        categories, up to 10 rare categories ordered from rarest to least
        rare, and the index labels of the affected rows.
    """
    col_data = df[col].dropna()
    total = len(col_data)

    if total == 0:
        return None

    value_counts = col_data.value_counts()

    # Consider categories with frequency < 1% as outliers
    threshold = max(1, int(total * 0.01))
    rare_categories = value_counts[value_counts < threshold]

    if rare_categories.empty:
        return None

    # Rows belonging to any rare category, selected in a single pass
    outlier_indices = set(col_data[col_data.isin(rare_categories.index)].index)

    # value_counts() is sorted most-frequent first, so re-sort ascending to
    # really get the rarest categories first (stable sort keeps ties deterministic).
    top_rare_categories = (
        rare_categories.sort_values(kind="stable").head(10).index.tolist()
    )

    return {
        "column": col,
        "outlier_count": len(outlier_indices),
        "outlier_percentage": round(len(outlier_indices) / total * 100, 2),
        "method_used": "rare_category_detection",
        "top_outlier_values": top_rare_categories,  # Top 10 rare categories
        "outlier_indices": list(outlier_indices)  # Keep temporarily for pattern calculation
    }


def compute_outlier_summary(
    df: pd.DataFrame,
    outliers_by_column: Dict,
    all_outlier_indices: set
) -> Dict:
    """
    Summarise an outlier scan at dataset level.

    Returns the share of rows flagged at least once (as a decimal rounded to
    4 places, 0 for an empty frame), the number of such rows, the number of
    columns with findings and an impact label: "High" above 10%, "Medium"
    above 5%, "Low" otherwise.
    """
    n_rows = len(df)
    n_flagged = len(all_outlier_indices)

    if n_rows > 0:
        rate = n_flagged / n_rows
    else:
        rate = 0

    # Grade the share of affected rows.
    impact = "Low"
    if rate > 0.10:
        impact = "High"
    elif rate > 0.05:
        impact = "Medium"

    return {
        "outlier_rate": round(rate, 4),  # decimal fraction, e.g. 0.038 for 3.8%
        "outlier_count": n_flagged,
        "columns_with_outliers": len(outliers_by_column),
        "impact_label": impact,
    }


def extract_outlier_records(
    df: pd.DataFrame,
    outlier_indices: List[int],
    feature_columns: List[str]
) -> List[Dict]:
    """
    Return the flagged rows as a list of ``{column: value}`` dicts.

    Rows are looked up by index label in the order given, restricted to
    ``feature_columns`` and capped at the first 100 records. An empty
    ``outlier_indices`` yields an empty list.
    """
    max_records = 100

    if outlier_indices:
        flagged_rows = df.loc[outlier_indices, feature_columns]
        records = flagged_rows.to_dict('records')
        return records[:max_records]

    return []


def identify_outlier_patterns(
    df: pd.DataFrame,
    outliers_by_column: Dict,
    all_outlier_indices: set
) -> Dict:
    """
    Derive cross-column patterns from per-column outlier results.

    Reports how many rows are flagged in two or more columns (based on each
    entry's ``"outlier_indices"``) and the names of the three columns with the
    highest ``"outlier_count"``. ``df`` and ``all_outlier_indices`` are
    accepted but not used.
    """
    # Tally, per row label, the number of columns that flagged it.
    hits_per_row = {}
    for info in outliers_by_column.values():
        for row in info.get("outlier_indices", []):
            hits_per_row[row] = 1 + hits_per_row.get(row, 0)

    repeated_rows = sum(1 for hits in hits_per_row.values() if hits > 1)

    # Columns ranked by outlier count; ties keep their original order.
    ranked_columns = sorted(
        outliers_by_column,
        key=lambda name: outliers_by_column[name].get("outlier_count", 0),
        reverse=True,
    )

    return {
        "multi_column_outliers_count": repeated_rows,
        "most_common_outlier_columns": ranked_columns[:3],
    }


def generate_outlier_recommendations(
    outlier_summary: Dict,
    outliers_by_column: Dict,
    outlier_patterns: Dict
) -> List[str]:
    """
    Produce follow-up advice from the outlier findings.

    Advice is emitted for a high (>10%) or moderate (>5%) overall outlier
    rate, for rows flagged in several columns and for the column with the
    highest outlier percentage. If none applies, a single "within expected
    ranges" message is returned.
    """
    advice = []

    # The summary stores the rate as a fraction; messages show a percentage.
    pct = outlier_summary.get("outlier_rate", 0) * 100
    if pct > 10:
        advice.append(
            f"High proportion of outliers detected ({pct:.1f}%). "
            "Consider investigating data quality or collection methods."
        )
    elif pct > 5:
        advice.append(
            f"Moderate proportion of outliers detected ({pct:.1f}%). "
            "Review outlier records to determine if they represent valid extreme values or errors."
        )

    repeated = outlier_patterns.get("multi_column_outliers_count", 0)
    if repeated > 0:
        advice.append(
            f"{repeated} rows identified as outliers across multiple columns. "
            "These may indicate systematic data quality issues."
        )

    # Column with the highest outlier percentage
    if outliers_by_column:
        worst = max(
            outliers_by_column,
            key=lambda name: outliers_by_column[name].get("outlier_percentage", 0),
        )
        worst_pct = outliers_by_column[worst].get('outlier_percentage', 0)
        advice.append(
            f"Column '{worst}' has the highest outlier rate "
            f"({worst_pct:.1f}%). Consider detailed investigation."
        )

    if advice:
        return advice

    return ["Outlier detection completed. Outlier rates are within expected ranges."]


def generate_outlier_insights_summary(
    outlier_summary: Dict,
    outliers_by_column: Dict,
    outlier_patterns: Dict
) -> str:
    """
    Compose a short narrative of the outlier scan.

    The text reports the number and percentage of affected rows, the number
    of affected columns, optionally the number of rows flagged in several
    columns, and the column with the highest outlier percentage.
    """
    flagged_rows = outlier_summary.get("outlier_count", 0)
    # Stored as a fraction, displayed as a percentage.
    flagged_pct = outlier_summary.get("outlier_rate", 0) * 100
    affected_columns = outlier_summary.get("columns_with_outliers", 0)

    sentences = [
        f"Outlier detection identified {flagged_rows} rows with outliers ({flagged_pct:.1f}%).",
        f"{affected_columns} columns contain outliers.",
    ]

    repeated = outlier_patterns.get("multi_column_outliers_count", 0)
    if repeated > 0:
        sentences.append(
            f"{repeated} rows are outliers across multiple columns, "
            "suggesting potential systematic issues."
        )

    # Column with the highest outlier percentage
    if outliers_by_column:
        worst = max(
            outliers_by_column,
            key=lambda name: outliers_by_column[name].get("outlier_percentage", 0),
        )
        worst_pct = outliers_by_column[worst].get('outlier_percentage', 0)
        sentences.append(f"Most affected column: {worst} ({worst_pct:.1f}% outliers).")

    return " ".join(sentences)


#============== Root Cause Analysis Utilities ==============

def compute_target_summary(df: pd.DataFrame, target_column: str) -> Dict:
    """
    Describe the target variable with basic statistics.

    The result holds mean, median, std, min, max, the row count (missing
    values included), the number of missing values, the 25/50/75/90/95th
    percentiles and an ``is_binary`` flag. When the target has exactly two
    distinct non-null values, their counts are added under ``value_counts``.
    """
    target = df[target_column]

    summary = {
        "mean": float(target.mean()),
        "median": float(target.median()),
        "std": float(target.std()),
        "min": float(target.min()),
        "max": float(target.max()),
        "count": int(len(target)),
        "missing_count": int(target.isnull().sum()),
    }

    # Percentiles, keyed "p25", "p50", ...
    summary.update(
        {f"p{pct}": float(target.quantile(pct / 100)) for pct in (25, 50, 75, 90, 95)}
    )

    # Two distinct values => binary target
    is_binary = bool(target.nunique() == 2)
    summary["is_binary"] = is_binary
    if is_binary:
        summary["value_counts"] = target.value_counts().to_dict()

    return summary


def compute_rca_feature_importance(
    df: pd.DataFrame,
    target_column: str,
    numeric_cols: List[str],
    categorical_cols: List[str]
) -> Dict[str, float]:
    """
    Computes feature importance for root cause analysis.

    Numeric features are scored as ``|pearson r| * significance`` and
    categorical features as ``min(log1p(F) / 10, 1) * significance`` using a
    one-way ANOVA of the target across categories, where
    ``significance = 1 - min(p, 0.05) / 0.05`` (0 when p >= 0.05).

    Rows where either the feature or the target is missing are excluded
    pairwise, so a few missing target values no longer force the score to 0.
    A feature that cannot be scored (too few rows, constant data, failure in
    the statistical routine) gets importance 0.0; failures are logged.

    Parameters
    ----------
    df : pd.DataFrame
        Data containing the target and all listed features.
    target_column : str
        Name of the numeric target column.
    numeric_cols, categorical_cols : List[str]
        Features to score.

    Returns
    -------
    Dict[str, float]
        Importance per feature (rounded to 4 decimals), sorted descending.
    """
    importance = {}
    target = df[target_column]
    target_present = target.notna()

    # For numeric features: use correlation strength
    for col in numeric_cols:
        importance[col] = 0.0
        try:
            valid = df[col].notna() & target_present
            if valid.sum() < 2:
                continue  # pearsonr needs at least two paired observations
            corr, p_value = pearsonr(df.loc[valid, col], target[valid])
            if not np.isnan(corr):
                # Combine correlation strength and significance
                significance = 1 - min(p_value, 0.05) / 0.05
                importance[col] = round(float(abs(corr) * significance), 4)
        except Exception:
            logging.exception("Could not compute numeric importance for column %r", col)

    # For categorical features: use variance explained
    for col in categorical_cols:
        importance[col] = 0.0
        try:
            valid = df[col].notna() & target_present
            categories = df.loc[valid, col]
            valid_target = target[valid]

            # Target values per category; groups with a single value carry no variance
            groups = []
            for cat in categories.unique():
                values = valid_target[categories == cat].values
                if len(values) > 1:
                    groups.append(values)

            if len(groups) < 2:
                continue

            # ANOVA F-statistic as importance measure
            f_stat, p_value = f_oneway(*groups)
            if not np.isnan(f_stat) and f_stat > 0:
                significance = 1 - min(p_value, 0.05) / 0.05
                # Normalize F-statistic (using log scale)
                normalized_f = min(np.log1p(f_stat) / 10, 1.0)
                importance[col] = round(float(normalized_f * significance), 4)
        except Exception:
            logging.exception("Could not compute categorical importance for column %r", col)

    # Sort by importance
    return dict(sorted(importance.items(), key=lambda x: x[1], reverse=True))


def compute_correlations(
    df: pd.DataFrame,
    target_column: str,
    numeric_cols: List[str]
) -> Dict:
    """
    Computes correlation analysis between target and numeric features.

    For every numeric feature with at least 3 rows where both the feature and
    the target are present, Pearson and Spearman correlations are computed.
    Features for which either coefficient is NaN, or for which the
    computation fails (the failure is logged), are omitted.

    Parameters
    ----------
    df : pd.DataFrame
        Data containing the target and the listed features.
    target_column : str
        Name of the numeric target column.
    numeric_cols : List[str]
        Numeric features to correlate with the target.

    Returns
    -------
    Dict
        ``{"pearson": {...}, "spearman": {...}, "top_correlated": [...]}``.
        Each per-feature entry has ``correlation`` (4 decimals), ``p_value``
        (6 decimals) and ``significant`` (plain ``bool``, p < 0.05).
        ``top_correlated`` lists up to 10 features by absolute Pearson
        correlation.
    """
    correlations = {
        "pearson": {},
        "spearman": {},
        "top_correlated": []
    }

    for col in numeric_cols:
        try:
            # Pairwise removal of missing values
            valid_mask = df[[col, target_column]].notna().all(axis=1)
            if valid_mask.sum() < 3:
                continue

            x = df.loc[valid_mask, col]
            y = df.loc[valid_mask, target_column]

            pearson_corr, pearson_p = pearsonr(x, y)
            # Spearman correlation (rank-based)
            spearman_corr, spearman_p = spearmanr(x, y)
        except Exception:
            logging.exception("Could not compute correlations for column %r", col)
            continue

        if np.isnan(pearson_corr) or np.isnan(spearman_corr):
            continue

        # bool(...) converts numpy.bool_ so the result stays JSON-serialisable
        correlations["pearson"][col] = {
            "correlation": round(float(pearson_corr), 4),
            "p_value": round(float(pearson_p), 6),
            "significant": bool(pearson_p < 0.05)
        }
        correlations["spearman"][col] = {
            "correlation": round(float(spearman_corr), 4),
            "p_value": round(float(spearman_p), 6),
            "significant": bool(spearman_p < 0.05)
        }

    # Top correlated features
    if correlations["pearson"]:
        top_pearson = sorted(
            correlations["pearson"].items(),
            key=lambda x: abs(x[1]["correlation"]),
            reverse=True
        )[:10]
        correlations["top_correlated"] = [
            {
                "feature": feat,
                "pearson_corr": vals["correlation"],
                "pearson_p": vals["p_value"],
                "spearman_corr": correlations["spearman"].get(feat, {}).get("correlation", 0),
                "significant": vals["significant"]
            }
            for feat, vals in top_pearson
        ]

    return correlations


def compute_categorical_impacts(
    df: pd.DataFrame,
    target_column: str,
    categorical_cols: List[str],
    max_top_categories: int = 5
) -> Dict:
    """
    Computes impact of categorical features on target variable.
    Returns only top categories to avoid saturating agent context.

    Missing target values are ignored, and categories whose within-category
    variance is undefined (a single observation) are left out of the
    within-variance average instead of forcing the F-ratio to 0.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing the data
    target_column : str
        Name of the target variable
    categorical_cols : List[str]
        List of categorical column names
    max_top_categories : int, default=5
        Maximum number of top categories to keep per feature

    Returns
    -------
    Dict
        One entry per column that has at least two categories with observed
        target values (columns that fail are logged and skipped), containing:
        - f_ratio: variance of category means / mean within-category variance
        - n_categories: Total number of categories with observed target values
        - most_impactful_category: Category with highest absolute difference
          from the overall target mean
        - top_categories: List of top N categories with simplified stats
    """
    impacts = {}

    for col in categorical_cols:
        try:
            overall_mean = df[target_column].mean()

            # Stats for every category first (all are needed for the F-ratio)
            all_category_stats = {}
            for category in df[col].dropna().unique():
                cat_data = df.loc[df[col] == category, target_column].dropna()
                if cat_data.empty:
                    continue
                cat_mean = float(cat_data.mean())
                all_category_stats[str(category)] = {
                    "mean": cat_mean,
                    "std": float(cat_data.std()),
                    "difference_from_overall": float(cat_mean - overall_mean)
                }

            # A single category gives nothing to compare
            if len(all_category_stats) <= 1:
                continue

            category_means = [info["mean"] for info in all_category_stats.values()]
            between_variance = np.var(category_means)

            # std is NaN for single-observation categories; skip those here
            category_variances = [
                info["std"] ** 2
                for info in all_category_stats.values()
                if not np.isnan(info["std"])
            ]
            within_variance = np.mean(category_variances) if category_variances else 0.0

            f_ratio = between_variance / within_variance if within_variance > 0 else 0.0

            # Categories ranked by absolute deviation from the overall mean
            ranked_categories = sorted(
                all_category_stats.items(),
                key=lambda item: abs(item[1]["difference_from_overall"]),
                reverse=True
            )
            most_impactful = ranked_categories[0][0]

            # Build simplified top_categories list
            top_categories = []
            for cat_name, cat_stats in ranked_categories[:max_top_categories]:
                difference = cat_stats["difference_from_overall"]
                pct_diff = (difference / overall_mean * 100) if overall_mean != 0 else 0.0
                top_categories.append({
                    "category": cat_name,
                    "mean": cat_stats["mean"],
                    "difference_from_overall": difference,
                    "pct_difference": round(float(pct_diff), 2)
                })

            impacts[col] = {
                "f_ratio": round(float(f_ratio), 4),
                "n_categories": len(all_category_stats),
                "most_impactful_category": most_impactful,
                "top_categories": top_categories
            }
        except Exception:
            # Best-effort: a failing column is reported and skipped
            logging.exception(
                "Error processing categorical variable %r for root cause analysis", col
            )
            continue

    return impacts


def compute_feature_interactions(
    df: pd.DataFrame,
    target_column: str,
    feature_columns: List[str],
    numeric_cols: List[str],
    categorical_cols: List[str],
    max_interactions: int
) -> List[Dict]:
    """
    Identifies important feature interactions with respect to the target.

    Two kinds of interaction are screened:

    * numeric x numeric: the element-wise product of two numeric columns is
      correlated (Pearson) with the target; pairs with |r| > 0.1 are kept.
    * categorical x numeric: the Pearson correlation between a numeric column
      and the target is computed separately inside each category; the pair is
      kept when the variance of those per-category correlations exceeds 0.1.

    Rows where either side of a correlation is missing are ignored, so NaNs
    in the features or in the target do not invalidate a pair.

    Parameters
    ----------
    df : pd.DataFrame
        Data containing the features and the target
    target_column : str
        Name of the target column
    feature_columns : List[str]
        Not used by this function; kept for interface compatibility
    numeric_cols : List[str]
        Numeric feature names (pairs are formed among the first 6)
    categorical_cols : List[str]
        Categorical feature names (only the first 3 are analyzed)
    max_interactions : int
        Maximum number of interactions to return

    Returns
    -------
    List[Dict]
        Interaction records sorted by |correlation| (numeric pairs) or
        correlation variance (categorical pairs), strongest first
    """
    interactions = []

    # Analyze numeric-numeric interactions
    for i, col1 in enumerate(numeric_cols[:5]):  # Limit to avoid too many combinations
        for col2 in numeric_cols[i + 1:6]:
            try:
                # Create interaction term (product)
                interaction = df[col1] * df[col2]
                target = df[target_column]

                # pearsonr cannot handle NaN on either side: keep complete rows only
                valid = interaction.notna() & target.notna()
                if valid.sum() < 2:
                    continue

                corr, p_value = pearsonr(interaction[valid], target[valid])

                if not np.isnan(corr) and abs(corr) > 0.1:
                    interactions.append({
                        "type": "numeric_numeric",
                        "features": [col1, col2],
                        "correlation": round(float(corr), 4),
                        "p_value": round(float(p_value), 6),
                        # Plain bool (not numpy.bool_) so the record serializes cleanly
                        "significant": bool(p_value < 0.05)
                    })
            except Exception:
                logging.exception(
                    "Error computing numeric interaction between %s and %s", col1, col2
                )
                continue

    # Analyze categorical-numeric interactions
    for cat_col in categorical_cols[:3]:  # Limit categories
        for num_col in numeric_cols[:5]:
            try:
                # Compute correlation within each category
                category_corrs = {}
                for category in df[cat_col].dropna().unique()[:5]:  # Limit categories
                    cat_data = df[df[cat_col] == category]
                    if len(cat_data) > 3:
                        valid = cat_data[num_col].notna() & cat_data[target_column].notna()
                        if valid.sum() < 2:
                            continue
                        corr, _ = pearsonr(
                            cat_data.loc[valid, num_col],
                            cat_data.loc[valid, target_column]
                        )
                        if not np.isnan(corr):
                            category_corrs[str(category)] = round(float(corr), 4)

                # Check if correlation varies significantly across categories
                if len(category_corrs) > 1:
                    corr_variance = np.var(list(category_corrs.values()))
                    if corr_variance > 0.1:  # Significant variation
                        interactions.append({
                            "type": "categorical_numeric",
                            "features": [cat_col, num_col],
                            "correlation_by_category": category_corrs,
                            "correlation_variance": round(float(corr_variance), 4)
                        })
            except Exception:
                logging.exception(
                    "Error computing categorical interaction between %s and %s", cat_col, num_col
                )
                continue

    # Sort by importance and limit
    interactions = sorted(interactions,
                          key=lambda x: abs(x.get("correlation", 0)) if "correlation" in x
                          else x.get("correlation_variance", 0),
                          reverse=True)[:max_interactions]

    return interactions


def _numeric_shift_condition(
    df: pd.DataFrame,
    subset: pd.DataFrame,
    col: str,
    subset_mean_key: str
) -> Optional[Dict]:
    """Returns a condition record when the subset mean of `col` differs from its overall mean by more than half the overall std, else None."""
    subset_mean = subset[col].mean()
    overall_mean = df[col].mean()
    if abs(subset_mean - overall_mean) > df[col].std() * 0.5:
        return {
            "feature": col,
            subset_mean_key: round(float(subset_mean), 3),
            "overall_mean": round(float(overall_mean), 3),
            "difference": round(float(subset_mean - overall_mean), 3)
        }
    return None


def compute_conditional_analysis(
    df: pd.DataFrame,
    target_column: str,
    feature_columns: List[str],
    numeric_cols: List[str],
    categorical_cols: List[str]
) -> Dict:
    """
    Performs conditional analysis to identify patterns.

    Rows are split into a "high" group (target above mean + 1 std) and a
    "low" group (target below mean - 1 std). For the first 10 feature
    columns:

    * numeric features are reported when the group mean differs from the
      overall mean by more than half of the feature's overall std
      (both groups);
    * categorical features are reported per category when the category's
      share in the high group exceeds its overall share by more than 0.1
      (high group only).

    A feature that fails to analyze is logged and skipped.

    Parameters
    ----------
    df : pd.DataFrame
        Data containing the features and the target
    target_column : str
        Name of the target column
    feature_columns : List[str]
        Features to analyze (only the first 10 are considered)
    numeric_cols : List[str]
        Names of the numeric features
    categorical_cols : List[str]
        Names of the categorical features

    Returns
    -------
    Dict
        Keys "high_target_conditions" and "low_target_conditions" (each
        truncated to 5 records) and "anomaly_conditions" (always empty here)
    """
    conditional = {
        "high_target_conditions": [],
        "low_target_conditions": [],
        "anomaly_conditions": []
    }

    target_mean = df[target_column].mean()
    target_std = df[target_column].std()
    high_threshold = target_mean + target_std
    low_threshold = target_mean - target_std

    # Identify conditions for high target values
    high_target_df = df[df[target_column] > high_threshold]
    if len(high_target_df) > 0:
        for col in feature_columns[:10]:  # Limit analysis
            try:
                if col in numeric_cols:
                    record = _numeric_shift_condition(df, high_target_df, col, "high_target_mean")
                    if record is not None:
                        conditional["high_target_conditions"].append(record)
                elif col in categorical_cols:
                    high_dist = high_target_df[col].value_counts(normalize=True)
                    overall_dist = df[col].value_counts(normalize=True)

                    # Find categories over-represented in high target
                    for cat in high_dist.index:
                        if cat in overall_dist.index:
                            diff = high_dist[cat] - overall_dist[cat]
                            if diff > 0.1:  # At least 10 percentage points more
                                conditional["high_target_conditions"].append({
                                    "feature": col,
                                    "category": str(cat),
                                    "high_target_pct": round(float(high_dist[cat] * 100), 2),
                                    "overall_pct": round(float(overall_dist[cat] * 100), 2),
                                    "difference": round(float(diff * 100), 2)
                                })
            except Exception:
                logging.exception("Error analyzing high-target conditions for feature %s", col)
                continue

    # Identify conditions for low target values (numeric features only)
    low_target_df = df[df[target_column] < low_threshold]
    if len(low_target_df) > 0:
        for col in feature_columns[:10]:
            try:
                if col in numeric_cols:
                    record = _numeric_shift_condition(df, low_target_df, col, "low_target_mean")
                    if record is not None:
                        conditional["low_target_conditions"].append(record)
            except Exception:
                logging.exception("Error analyzing low-target conditions for feature %s", col)
                continue

    # Limit results
    conditional["high_target_conditions"] = conditional["high_target_conditions"][:5]
    conditional["low_target_conditions"] = conditional["low_target_conditions"][:5]

    return conditional


def identify_root_causes(
    feature_importance: Dict[str, float],
    correlations: Dict,
    categorical_impacts: Dict,
    interactions: List[Dict],
    min_threshold: float
) -> List[Dict]:
    """
    Identifies root causes based on all analyses.

    The five features with the highest importance are considered; each one
    whose importance is at least `min_threshold` becomes a root cause,
    annotated with correlation and categorical-impact evidence when
    available. The first three interactions are added as root causes when
    they are flagged significant or have a correlation variance above 0.2.

    Parameters
    ----------
    feature_importance : Dict[str, float]
        Mapping feature name -> importance score (any order)
    correlations : Dict
        Correlation results; `correlations["pearson"][feature]` is read for
        its "correlation" and "significant" entries when present
    categorical_impacts : Dict
        Mapping feature name -> dict with "f_ratio" and
        "most_impactful_category"
    interactions : List[Dict]
        Interaction records (each with "features" and either "correlation"
        or "correlation_variance")
    min_threshold : float
        Minimum importance for a feature to be reported

    Returns
    -------
    List[Dict]
        Root-cause records sorted by "importance_score", highest first
    """
    root_causes = []

    # Rank explicitly so the result does not depend on the dict's insertion
    # order (sorted() is stable, so an already-sorted dict keeps its order).
    top_features = sorted(
        feature_importance.items(),
        key=lambda item: item[1],
        reverse=True
    )[:5]

    for feat, importance in top_features:
        if importance < min_threshold:
            continue

        evidence = []

        # Add correlation evidence if numeric
        if feat in correlations.get("pearson", {}):
            corr_info = correlations["pearson"][feat]
            evidence.append({
                "type": "correlation",
                "strength": abs(corr_info["correlation"]),
                "direction": "positive" if corr_info["correlation"] > 0 else "negative",
                "significant": corr_info["significant"]
            })

        # Add categorical impact evidence
        if feat in categorical_impacts:
            impact_info = categorical_impacts[feat]
            evidence.append({
                "type": "categorical_impact",
                "f_ratio": impact_info["f_ratio"],
                "most_impactful": impact_info["most_impactful_category"]
            })

        root_causes.append({
            "feature": feat,
            "importance_score": importance,
            "evidence": evidence,
            "confidence": "high" if importance > 0.5 else "medium" if importance > 0.3 else "low"
        })

    # Add significant interactions as potential root causes
    for interaction in interactions[:3]:
        if interaction.get("significant", False) or interaction.get("correlation_variance", 0) > 0.2:
            # Records without a single "correlation" (per-category interactions)
            # are scored by their correlation variance instead of 0.
            if "correlation" in interaction:
                score = abs(interaction["correlation"])
            else:
                score = interaction.get("correlation_variance", 0)

            root_causes.append({
                "feature": "interaction",
                "features": interaction["features"],
                "importance_score": score,
                "evidence": [{"type": "interaction", "details": interaction}],
                "confidence": "medium"
            })

    # Sort by importance
    root_causes = sorted(root_causes, key=lambda x: x["importance_score"], reverse=True)

    return root_causes


def generate_rca_insights_summary(
    target_summary: Dict,
    feature_importance: Dict[str, float],
    correlations: Dict,
    categorical_impacts: Dict,
    root_causes: List[Dict]
) -> str:
    """
    Builds a one-paragraph, human-readable recap of a root cause analysis.

    The paragraph always opens with the target statistics and then adds,
    when the corresponding input is non-empty: the first three root causes,
    the first entry of `correlations["top_correlated"]`, and the categorical
    feature with the largest F-ratio (only if that ratio exceeds 1.0).
    `feature_importance` is accepted but not read.

    Returns
    -------
    str
        The sentences joined by single spaces
    """
    mean_value = target_summary["mean"]
    std_value = target_summary["std"]
    sentences = [
        f"Root cause analysis performed on target variable with mean={mean_value:.2f}, "
        f"std={std_value:.2f} (n={target_summary['count']})."
    ]

    # Up to three leading root causes; interaction records carry a "features" list
    if root_causes:
        labels = [
            f"interaction between {', '.join(entry['features'])}"
            if "features" in entry
            else entry["feature"]
            for entry in root_causes[:3]
        ]
        sentences.append(f"Top identified root causes: {', '.join(labels)}.")

    # Best-correlated feature, as ranked by the caller
    ranked_correlations = correlations.get("top_correlated")
    if ranked_correlations:
        leader = ranked_correlations[0]
        sentences.append(
            f"Strongest correlation: {leader['feature']} "
            f"(r={leader['pearson_corr']:.3f}, "
            f"{'significant' if leader['significant'] else 'not significant'})."
        )

    # Categorical feature with the highest F-ratio, mentioned only above 1.0
    if categorical_impacts:
        best_name, best_info = max(
            categorical_impacts.items(),
            key=lambda pair: pair[1].get("f_ratio", 0)
        )
        if best_info.get("f_ratio", 0) > 1.0:
            sentences.append(
                f"Strongest categorical impact: {best_name} "
                f"(F-ratio={best_info['f_ratio']:.2f})."
            )

    return " ".join(sentences)


#============== Time Series Forecasting Utilities ==============

def aggregate_time_series(
    df: pd.DataFrame,
    date_column: str,
    value_columns: List[str],
    freq: str,
    method: str = 'mean'
) -> pd.DataFrame:
    """
    Aggregates time series data by a specified frequency.

    Value columns that do not exist in `df` are skipped. After aggregation,
    rows containing a missing value in any column are dropped (for example
    empty periods under 'mean').

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame with time series data (not modified)
    date_column : str
        Name of the date column
    value_columns : List[str]
        List of value columns to aggregate
    freq : str
        Pandas frequency string (e.g., 'D', 'W', 'M', 'Q', 'Y')
    method : str, default='mean'
        Aggregation method: 'mean', 'sum', 'median', 'min', 'max', 'count'

    Returns
    -------
    pd.DataFrame
        Aggregated DataFrame with date_column and aggregated value_columns

    Raises
    ------
    ValueError
        If `method` is not one of the supported aggregation methods
    KeyError
        If `value_columns` is non-empty and none of them exist in `df`
    """
    supported_methods = ('mean', 'sum', 'median', 'min', 'max', 'count')
    if method not in supported_methods:
        raise ValueError(f"Unknown aggregation method: {method}. Use 'mean', 'sum', 'median', 'min', 'max', or 'count'")

    # Ensure date column is datetime and set as index
    df_agg = df.copy()
    df_agg[date_column] = pd.to_datetime(df_agg[date_column])
    df_agg = df_agg.set_index(date_column)

    # Aggregate only the requested columns that actually exist
    present_columns = [col for col in value_columns if col in df_agg.columns]
    if value_columns and not present_columns:
        raise KeyError(f"None of the value columns {list(value_columns)} are present in the DataFrame")

    # The method name was validated above, so it maps directly to the
    # resampler method of the same name
    resampler = df_agg[present_columns].resample(freq)
    df_agg = getattr(resampler, method)()

    # Reset index to get date_column back
    df_agg = df_agg.reset_index()
    df_agg = df_agg.dropna()

    return df_agg


def detect_date_frequency(dates: np.ndarray) -> Dict:
    """
    Detects the frequency of dates in a time series and determines if dates are regular.
    Returns information about date frequency and regularity.

    Dates are converted with `pd.to_datetime` and sorted before the
    intervals between consecutive dates are measured, so the input order
    does not matter.

    Parameters
    ----------
    dates : np.ndarray
        Array of datetime values

    Returns
    -------
    Dict
        Dictionary containing (every key is present on every return path):
        - is_regular : bool indicating if dates are regular
          (coefficient of variation below 0.3)
        - frequency : detected frequency (median interval)
        - frequency_std : standard deviation of intervals (lower = more regular)
        - median_interval : median time interval between dates
        - mean_interval : mean time interval between dates
        - coefficient_of_variation : std / mean of the intervals
          (inf when the mean interval is not positive)

        When fewer than two dates (or no usable intervals) are available,
        is_regular is False and all other values are None.
    """
    # Result used when no interval can be measured; same keys as the normal
    # result so callers can read any key unconditionally.
    undetermined = {
        "is_regular": False,
        "frequency": None,
        "frequency_std": None,
        "median_interval": None,
        "mean_interval": None,
        "coefficient_of_variation": None
    }

    if len(dates) < 2:
        return undetermined

    # Convert to pandas Series for easier manipulation
    dates_sorted = pd.Series(pd.to_datetime(dates)).sort_values()

    # Calculate intervals between consecutive dates
    intervals = dates_sorted.diff().dropna()

    if len(intervals) == 0:
        return undetermined

    median_interval = intervals.median()
    mean_interval = intervals.mean()
    std_interval = intervals.std()

    # Consider dates regular if coefficient of variation < 0.3 (30% variation)
    cv = float(std_interval / mean_interval) if mean_interval.total_seconds() > 0 else float('inf')
    is_regular = cv < 0.3

    return {
        "is_regular": bool(is_regular),
        "frequency": median_interval,
        "frequency_std": std_interval,
        "median_interval": median_interval,
        "mean_interval": mean_interval,
        "coefficient_of_variation": cv
    }


def generate_future_dates(last_date: pd.Timestamp, forecast_horizon: int, date_frequency_info: Dict) -> List:
    """
    Builds the list of dates that follow `last_date` for a forecast.

    Each future date is `last_date` plus a whole multiple of the detected
    interval. When no usable interval is available, a one-day step is used.

    Parameters
    ----------
    last_date : pd.Timestamp
        Last date in the time series
    forecast_horizon : int
        Number of periods to forecast
    date_frequency_info : Dict
        Dictionary from detect_date_frequency(); only its "frequency"
        entry is read

    Returns
    -------
    List
        `forecast_horizon` future dates, earliest first
    """
    step = date_frequency_info.get("frequency")
    multiples = range(1, forecast_horizon + 1)

    if not step:
        # No (or zero-length) detected interval: fall back to daily steps
        return [last_date + pd.Timedelta(days=k) for k in multiples]

    return [last_date + step * k for k in multiples]


def analyze_trend(ts: np.ndarray, dates: np.ndarray) -> Dict:
    """
    Analyzes trend in time series.

    A straight line is fitted by least squares against the observation
    index (0, 1, 2, ...). Non-finite observations (NaN/inf) are excluded
    from the fit while the remaining points keep their original positions.
    A series whose finite values are all identical is reported as "stable"
    directly, so floating-point noise in a fitted slope cannot label it as
    increasing or decreasing.

    Parameters
    ----------
    ts : np.ndarray
        Time series values
    dates : np.ndarray
        Not used by this function; kept for interface compatibility

    Returns
    -------
    Dict
        {"trend": "insufficient_data"} when fewer than two usable
        observations exist; otherwise "trend" ("increasing", "decreasing"
        or "stable"), "slope", "intercept", "r_squared" and "strength"
        ("strong" if R^2 > 0.7, "moderate" if R^2 > 0.3, else "weak")
    """
    n = len(ts)
    if n < 2:
        return {"trend": "insufficient_data"}

    # Ensure numeric type
    ts = np.asarray(ts, dtype=np.float64)
    x = np.arange(n, dtype=np.float64)

    # polyfit cannot handle NaN/inf: fit on the finite observations only
    finite = np.isfinite(ts)
    if finite.sum() < 2:
        return {"trend": "insufficient_data"}
    x_fit = x[finite]
    y_fit = ts[finite]

    # Constant series: no trend by definition
    if np.all(y_fit == y_fit[0]):
        return {
            "trend": "stable",
            "slope": 0.0,
            "intercept": float(y_fit[0]),
            "r_squared": 0.0,
            "strength": "weak"
        }

    # Simple linear regression for trend
    slope, intercept = np.polyfit(x_fit, y_fit, 1)

    # Calculate R-squared
    y_pred = slope * x_fit + intercept
    ss_res = np.sum((y_fit - y_pred) ** 2)
    ss_tot = np.sum((y_fit - np.mean(y_fit)) ** 2)
    r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0.0

    trend_direction = "increasing" if slope > 0 else "decreasing" if slope < 0 else "stable"

    return {
        "trend": trend_direction,
        "slope": float(slope),
        "intercept": float(intercept),
        "r_squared": float(r_squared),
        "strength": "strong" if abs(r_squared) > 0.7 else "moderate" if abs(r_squared) > 0.3 else "weak"
    }


def analyze_seasonality(ts: np.ndarray, dates: np.ndarray, seasonal_period: Optional[int]) -> Dict:
    """
    Checks a series for a repeating pattern of a given period.

    The series is cut into complete cycles of `seasonal_period`
    observations (a trailing partial cycle is ignored) and the mean of each
    position within the cycle is computed. Seasonality is reported when the
    standard deviation of those positional means exceeds 10% of the
    standard deviation of the whole series.

    When `seasonal_period` is None, the first of 12, 7, 4, 52 that fits at
    least twice into the series is used. `dates` is accepted but not read.

    Returns
    -------
    Dict
        On success: "has_seasonality", "period", "n_seasons",
        "seasonal_strength" and "seasonal_pattern" (None when no
        seasonality is found). Otherwise "has_seasonality": False together
        with a "reason" string.
    """
    n_obs = len(ts)

    if n_obs < 12:
        return {"has_seasonality": False, "reason": "insufficient_data"}

    # Work on a float array from here on
    ts = np.asarray(ts, dtype=np.float64)

    # Pick a default period: first candidate that fits twice, else a capped half-length
    if seasonal_period is None:
        seasonal_period = next(
            (candidate for candidate in (12, 7, 4, 52) if n_obs >= candidate * 2),
            min(12, n_obs // 2)
        )

    if n_obs < seasonal_period * 2:
        return {"has_seasonality": False, "reason": "insufficient_data_for_period"}

    try:
        # Number of complete cycles available
        cycle_count = n_obs // seasonal_period
        if cycle_count < 2:
            return {"has_seasonality": False, "reason": "insufficient_seasons"}

        # One row per cycle, one column per position within the cycle
        usable = ts[:cycle_count * seasonal_period]
        by_cycle = usable.reshape(cycle_count, seasonal_period)
        position_means = np.mean(by_cycle, axis=0)

        spread_of_means = float(np.std(position_means))
        spread_overall = float(np.std(ts))

        # Seasonal when positional means vary noticeably relative to the series
        is_seasonal = spread_of_means > 0.1 * spread_overall

        if spread_overall > 0:
            strength = float(spread_of_means / spread_overall)
        else:
            strength = 0.0

        if is_seasonal:
            pattern = [float(value) for value in position_means.tolist()]
        else:
            pattern = None

        return {
            "has_seasonality": bool(is_seasonal),
            "period": seasonal_period,
            "n_seasons": cycle_count,
            "seasonal_strength": strength,
            "seasonal_pattern": pattern
        }
    except Exception as e:
        logging.exception("Error in seasonality analysis")
        return {"has_seasonality": False, "reason": f"analysis_failed: {str(e)}"}


def forecast_moving_average(ts_train: np.ndarray, ts_test: Optional[np.ndarray], forecast_horizon: int) -> Tuple[Dict, Optional[Dict]]:
    """
    Forecasts using moving average method.

    The forecast is flat: every future step equals the mean of the last
    `window` training values, with window = min(5, len(ts_train)). The
    confidence band is +/- 1.96 standard deviations of those same values.

    Parameters
    ----------
    ts_train : np.ndarray
        Training values
    ts_test : Optional[np.ndarray]
        Hold-out values used to score the forecast, or None
    forecast_horizon : int
        Number of future steps to forecast

    Returns
    -------
    Tuple[Dict, Optional[Dict]]
        (forecast_dict, accuracy). `accuracy` holds "mae", "rmse" and
        "mape" computed over the steps present in both the forecast and
        `ts_test`; it is None when there is nothing to compare. "mape" is
        None when every compared actual value is too close to zero.
    """
    # Ensure numeric type
    ts_train = np.asarray(ts_train, dtype=np.float64)
    if ts_test is not None:
        ts_test = np.asarray(ts_test, dtype=np.float64)

    # Use last N values for moving average (window = min(5, len(ts_train)))
    window = min(5, len(ts_train))
    last_values = ts_train[-window:]
    ma_value = float(np.mean(last_values))

    # Calculate standard deviation for confidence intervals
    std_dev = float(np.std(last_values))

    # Forecast: repeat the moving average
    forecast = np.full(forecast_horizon, ma_value, dtype=np.float64)

    # Calculate confidence intervals (95% confidence: ±1.96 * std)
    confidence_multiplier = 1.96
    forecast_lower = (forecast - confidence_multiplier * std_dev).tolist()
    forecast_upper = (forecast + confidence_multiplier * std_dev).tolist()

    forecast_dict = {
        "method": "moving_average",
        "window": window,
        "forecast_values": forecast.tolist(),
        "forecast_lower": forecast_lower,
        "forecast_upper": forecast_upper,
        "forecast_mean": float(ma_value)
    }

    # Calculate accuracy if test data available
    accuracy = None
    if ts_test is not None and len(ts_test) > 0:
        # Compare only the overlapping steps: the test set may be longer
        # than the forecast horizon
        n_eval = min(len(ts_test), len(forecast))
        if n_eval > 0:
            actual = ts_test[:n_eval]
            errors = actual - forecast[:n_eval]
            mae = float(np.mean(np.abs(errors)))
            rmse = float(np.sqrt(np.mean(errors ** 2)))

            # MAPE only on actuals that are not too close to zero (at least
            # 1% of the largest absolute actual, and at least 1e-6) to avoid
            # MAPE explosion
            threshold = max(np.abs(actual).max() * 0.01, 1e-6)
            valid_mask = np.abs(actual) >= threshold

            if valid_mask.sum() > 0:
                denominator = np.abs(actual[valid_mask]) + 1e-10
                mape = float(np.mean(np.abs(errors[valid_mask] / denominator)) * 100)
            else:
                # If all values are too small, set MAPE to None
                mape = None

            accuracy = {
                "mae": mae,
                "rmse": rmse,
                "mape": mape
            }

    return forecast_dict, accuracy


def forecast_linear_trend(ts_train: np.ndarray, ts_test: Optional[np.ndarray], forecast_horizon: int) -> Tuple[Dict, Optional[Dict]]:
    """
    Forecasts using linear trend method.

    A straight line is fitted by least squares to the training values
    against their index (0 .. n-1) and extended over indices
    n .. n + forecast_horizon - 1.

    Parameters
    ----------
    ts_train : np.ndarray
        Training values
    ts_test : Optional[np.ndarray]
        Hold-out values used to score the forecast, or None
    forecast_horizon : int
        Number of future steps to forecast

    Returns
    -------
    Tuple[Dict, Optional[Dict]]
        (forecast_dict, accuracy). `accuracy` holds "mae", "rmse" and
        "mape" computed over the steps present in both the forecast and
        `ts_test`; it is None when there is nothing to compare. "mape" is
        None when every compared actual value is too close to zero.
    """
    # Ensure numeric type
    ts_train = np.asarray(ts_train, dtype=np.float64)
    if ts_test is not None:
        ts_test = np.asarray(ts_test, dtype=np.float64)

    n = len(ts_train)
    x = np.arange(n, dtype=np.float64)

    # Linear regression
    slope, intercept = np.polyfit(x, ts_train, 1)

    # Calculate residuals for confidence intervals
    fitted_values = slope * x + intercept
    residuals = ts_train - fitted_values
    std_dev = float(np.std(residuals))

    # Forecast: extend the trend
    forecast_x = np.arange(n, n + forecast_horizon, dtype=np.float64)
    forecast = slope * forecast_x + intercept

    # Calculate confidence intervals (95% confidence: ±1.96 * std).
    # Heuristic widening: the half-width is scaled by (1 + index / n), so it
    # is 2x the base width at the first forecast step and grows linearly.
    confidence_multiplier = 1.96
    half_width = confidence_multiplier * std_dev * (1 + forecast_x / n)
    forecast_lower = (forecast - half_width).tolist()
    forecast_upper = (forecast + half_width).tolist()

    forecast_dict = {
        "method": "linear_trend",
        "slope": float(slope),
        "intercept": float(intercept),
        "forecast_values": forecast.tolist(),
        "forecast_lower": forecast_lower,
        "forecast_upper": forecast_upper,
        "forecast_mean": float(np.mean(forecast))
    }

    # Calculate accuracy if test data available
    accuracy = None
    if ts_test is not None and len(ts_test) > 0:
        # Compare only the overlapping steps: the test set may be longer
        # than the forecast horizon
        n_eval = min(len(ts_test), len(forecast))
        if n_eval > 0:
            actual = ts_test[:n_eval]
            errors = actual - forecast[:n_eval]
            mae = float(np.mean(np.abs(errors)))
            rmse = float(np.sqrt(np.mean(errors ** 2)))

            # MAPE only on actuals that are not too close to zero (at least
            # 1% of the largest absolute actual, and at least 1e-6) to avoid
            # MAPE explosion
            threshold = max(np.abs(actual).max() * 0.01, 1e-6)
            valid_mask = np.abs(actual) >= threshold

            if valid_mask.sum() > 0:
                denominator = np.abs(actual[valid_mask]) + 1e-10
                mape = float(np.mean(np.abs(errors[valid_mask] / denominator)) * 100)
            else:
                # If all values are too small, set MAPE to None
                mape = None

            accuracy = {
                "mae": mae,
                "rmse": rmse,
                "mape": mape
            }

    return forecast_dict, accuracy


def forecast_exponential_smoothing(ts_train: np.ndarray, ts_test: Optional[np.ndarray], forecast_horizon: int) -> Tuple[Dict, Optional[Dict]]:
    """
    Forecasts using exponential smoothing.

    statsmodels' SimpleExpSmoothing (with an optimized smoothing level) is
    used when it can be imported and fitted. If that fails for any reason,
    the error is logged and a manual simple exponential smoothing with
    alpha = 0.3 is applied instead: the level starts at the first training
    value and is updated as alpha * value + (1 - alpha) * level. Either way
    the forecast is flat, with a confidence band of +/- 1.96 standard
    deviations of the in-sample one-step-ahead residuals.

    Parameters
    ----------
    ts_train : np.ndarray
        Training values
    ts_test : Optional[np.ndarray]
        Hold-out values used to score the forecast, or None
    forecast_horizon : int
        Number of future steps to forecast

    Returns
    -------
    Tuple[Dict, Optional[Dict]]
        (forecast_dict, accuracy). `accuracy` holds "mae", "rmse" and
        "mape" computed over the steps present in both the forecast and
        `ts_test`; it is None when there is nothing to compare. "mape" is
        None when every compared actual value is too close to zero.
    """
    # Ensure numeric type
    ts_train = np.asarray(ts_train, dtype=np.float64)
    if ts_test is not None:
        ts_test = np.asarray(ts_test, dtype=np.float64)

    try:
        from statsmodels.tsa.holtwinters import SimpleExpSmoothing

        # Fit exponential smoothing
        fitted_model = SimpleExpSmoothing(ts_train).fit(optimized=True)

        # Forecast
        forecast = np.asarray(fitted_model.forecast(forecast_horizon), dtype=np.float64)

        # Residuals for confidence intervals
        residuals = ts_train - np.asarray(fitted_model.fittedvalues, dtype=np.float64)
        alpha = float(fitted_model.params.get("smoothing_level", 0.5))
    except Exception:
        logging.exception("Error in forecast_exponential_smoothing, falling back to manual exponential smoothing")
        # Fallback: manual simple exponential smoothing with a fixed alpha
        alpha = 0.3
        level = ts_train[0]
        one_step_predictions = np.empty_like(ts_train)
        for idx, value in enumerate(ts_train):
            # Prediction for this observation is the level before seeing it
            one_step_predictions[idx] = level
            level = alpha * value + (1 - alpha) * level

        forecast = np.full(forecast_horizon, level, dtype=np.float64)
        residuals = ts_train - one_step_predictions

    # Calculate confidence intervals (95% confidence: ±1.96 * std)
    std_dev = float(np.std(residuals)) if residuals.size > 0 else 0.0
    confidence_multiplier = 1.96
    forecast_lower = (forecast - confidence_multiplier * std_dev).tolist()
    forecast_upper = (forecast + confidence_multiplier * std_dev).tolist()

    forecast_dict = {
        "method": "exponential_smoothing",
        "alpha": alpha,
        "forecast_values": forecast.tolist(),
        "forecast_lower": forecast_lower,
        "forecast_upper": forecast_upper,
        "forecast_mean": float(np.mean(forecast))
    }

    # Calculate accuracy if test data available
    accuracy = None
    if ts_test is not None and len(ts_test) > 0:
        # Compare only the overlapping steps: the test set may be longer
        # than the forecast horizon
        n_eval = min(len(ts_test), len(forecast))
        if n_eval > 0:
            actual = ts_test[:n_eval]
            errors = actual - forecast[:n_eval]
            mae = float(np.mean(np.abs(errors)))
            rmse = float(np.sqrt(np.mean(errors ** 2)))

            # MAPE only on actuals that are not too close to zero (at least
            # 1% of the largest absolute actual, and at least 1e-6) to avoid
            # MAPE explosion
            threshold = max(np.abs(actual).max() * 0.01, 1e-6)
            valid_mask = np.abs(actual) >= threshold

            if valid_mask.sum() > 0:
                denominator = np.abs(actual[valid_mask]) + 1e-10
                mape = float(np.mean(np.abs(errors[valid_mask] / denominator)) * 100)
            else:
                # If all values are too small, set MAPE to None
                mape = None

            accuracy = {
                "mae": mae,
                "rmse": rmse,
                "mape": mape
            }

    return forecast_dict, accuracy


def forecast_seasonal_decompose(ts_train: np.ndarray, ts_test: Optional[np.ndarray], forecast_horizon: int, period: int) -> Tuple[Dict, Optional[Dict]]:
    """
    Forecasts using seasonal decomposition.

    The training series is split with statsmodels' additive
    `seasonal_decompose`. The forecast holds the last available trend value
    constant and repeats the final `period` seasonal values; the confidence
    band is +/- 1.96 standard deviations of the decomposition residuals.

    If the decomposition cannot be imported or computed, the error is
    logged, the forecast falls back to the training mean (band based on the
    training standard deviation) and the returned dict carries the failure
    message under "error".

    Parameters
    ----------
    ts_train : np.ndarray
        Training values
    ts_test : Optional[np.ndarray]
        Hold-out values used to score the forecast, or None
    forecast_horizon : int
        Number of future steps to forecast
    period : int
        Length of one seasonal cycle, in observations

    Returns
    -------
    Tuple[Dict, Optional[Dict]]
        (forecast_dict, accuracy). `accuracy` holds "mae", "rmse" and
        "mape" computed over the steps present in both the forecast and
        `ts_test`; it is None when there is nothing to compare. "mape" is
        None when every compared actual value is too close to zero.
    """
    # Ensure numeric type
    ts_train = np.asarray(ts_train, dtype=np.float64)
    if ts_test is not None:
        ts_test = np.asarray(ts_test, dtype=np.float64)

    fallback_error = None
    try:
        from statsmodels.tsa.seasonal import seasonal_decompose

        # Decompose time series
        decomposition = seasonal_decompose(ts_train, model='additive', period=period)

        # Components as plain arrays (works for both Series and ndarray results)
        trend = np.asarray(decomposition.trend, dtype=np.float64)
        seasonal = np.asarray(decomposition.seasonal, dtype=np.float64)
        residual = np.asarray(decomposition.resid, dtype=np.float64)

        # Forecast: hold the last known trend value and repeat the seasonal
        # pattern; the last `period` seasonal values start at the phase of
        # the first forecast step
        last_trend = trend[~np.isnan(trend)][-1]
        seasonal_pattern = seasonal[-period:]
        forecast = np.array(
            [last_trend + seasonal_pattern[i % len(seasonal_pattern)] for i in range(forecast_horizon)],
            dtype=np.float64
        )

        # Residual spread for confidence intervals
        std_dev = float(np.std(residual[~np.isnan(residual)]))
    except Exception as e:
        logging.exception("Error in forecast_seasonal_decompose, using fallback")
        # Fallback: flat forecast at the training mean
        fallback_error = str(e)
        forecast = np.full(forecast_horizon, np.mean(ts_train), dtype=np.float64)
        std_dev = float(np.std(ts_train))

    # Calculate confidence intervals (95% confidence: ±1.96 * std)
    confidence_multiplier = 1.96
    forecast_lower = (forecast - confidence_multiplier * std_dev).tolist()
    forecast_upper = (forecast + confidence_multiplier * std_dev).tolist()

    forecast_dict = {
        "method": "seasonal_decompose",
        "period": period,
        "forecast_values": forecast.tolist(),
        "forecast_lower": forecast_lower,
        "forecast_upper": forecast_upper,
        "forecast_mean": float(np.mean(forecast))
    }
    if fallback_error is not None:
        forecast_dict["error"] = fallback_error

    # Calculate accuracy if test data available
    accuracy = None
    if ts_test is not None and len(ts_test) > 0:
        # Compare only the overlapping steps: the test set may be longer
        # than the forecast horizon
        n_eval = min(len(ts_test), len(forecast))
        if n_eval > 0:
            actual = ts_test[:n_eval]
            errors = actual - forecast[:n_eval]
            mae = float(np.mean(np.abs(errors)))
            rmse = float(np.sqrt(np.mean(errors ** 2)))

            # MAPE only on actuals that are not too close to zero (at least
            # 1% of the largest absolute actual, and at least 1e-6) to avoid
            # MAPE explosion
            threshold = max(np.abs(actual).max() * 0.01, 1e-6)
            valid_mask = np.abs(actual) >= threshold

            if valid_mask.sum() > 0:
                denominator = np.abs(actual[valid_mask]) + 1e-10
                mape = float(np.mean(np.abs(errors[valid_mask] / denominator)) * 100)
            else:
                # If all values are too small, set MAPE to None
                mape = None

            accuracy = {
                "mae": mae,
                "rmse": rmse,
                "mape": mape
            }

    return forecast_dict, accuracy


def _infer_future_frequency(dates) -> str:
    """Return the pandas frequency alias of ``dates``, or ``"D"`` if none can be inferred."""
    try:
        inferred = pd.infer_freq(pd.DatetimeIndex(dates))
    except (TypeError, ValueError):
        # pandas raises for fewer than three timestamps or non-datetime input.
        inferred = None
    return inferred or "D"


def _holdout_accuracy(actual: np.ndarray, predicted: np.ndarray) -> Dict:
    """Return MAE, RMSE and MAPE of ``predicted`` against ``actual`` (equal-length arrays)."""
    errors = actual - predicted
    mae = float(np.mean(np.abs(errors)))
    rmse = float(np.sqrt(np.mean(errors ** 2)))

    # MAPE explodes when actuals are close to zero, so only score the points
    # whose magnitude is at least 1% of the largest actual (never below 1e-6).
    threshold = max(np.abs(actual).max() * 0.01, 1e-6)
    valid_mask = np.abs(actual) >= threshold
    if valid_mask.any():
        valid_actual = actual[valid_mask]
        denominator = np.abs(valid_actual) + 1e-10
        mape = float(np.mean(np.abs((valid_actual - predicted[valid_mask]) / denominator)) * 100)
    else:
        # Every actual is too small for a meaningful percentage error.
        mape = None

    return {"mae": mae, "rmse": rmse, "mape": mape}


def forecast_prophet(ts_train: np.ndarray, dates_train: np.ndarray, ts_test: Optional[np.ndarray], dates_test: Optional[np.ndarray], forecast_horizon: int) -> Tuple[Optional[Dict], Optional[Dict]]:
    """
    Forecasts using Facebook Prophet.

    Parameters
    ----------
    ts_train : np.ndarray
        Training observations; converted to float64.
    dates_train : np.ndarray
        Timestamps of the training observations (anything ``pd.to_datetime`` accepts).
    ts_test : Optional[np.ndarray]
        Hold-out observations that follow the training data, or None.
    dates_test : Optional[np.ndarray]
        Accepted for signature parity; not used by this function.
    forecast_horizon : int
        Number of future periods to forecast.

    Returns
    -------
    Tuple[Optional[Dict], Optional[Dict]]
        (forecast_dict, accuracy). ``(None, None)`` when importing Prophet
        raises ImportError. On any other failure a constant mean forecast
        with an ``"error"`` entry is returned together with ``None``.
        ``accuracy`` holds ``mae``/``rmse``/``mape`` and is None when no
        hold-out data overlaps the forecast.
    """
    # Ensure numeric type
    ts_train = np.asarray(ts_train, dtype=np.float64)
    if ts_test is not None:
        ts_test = np.asarray(ts_test, dtype=np.float64)

    try:
        from prophet import Prophet

        # Prepare data for Prophet (requires 'ds' and 'y' columns)
        df_train = pd.DataFrame({
            'ds': pd.to_datetime(dates_train),
            'y': ts_train
        })

        # Fit Prophet model
        model = Prophet(
            yearly_seasonality='auto',
            weekly_seasonality='auto',
            daily_seasonality=False,
            seasonality_mode='additive'
        )
        model.fit(df_train)

        # Future rows must follow the spacing of the training data; Prophet
        # otherwise assumes daily steps, which misplaces forecasts for
        # weekly/monthly/hourly series.
        future = model.make_future_dataframe(
            periods=forecast_horizon,
            freq=_infer_future_frequency(df_train['ds'])
        )

        # Generate forecast
        forecast = model.predict(future)

        # The frame also contains the history; only the trailing rows are future periods.
        forecast_values = forecast['yhat'].tail(forecast_horizon).values.tolist()
        forecast_lower = forecast['yhat_lower'].tail(forecast_horizon).values.tolist()
        forecast_upper = forecast['yhat_upper'].tail(forecast_horizon).values.tolist()

        forecast_dict = {
            "method": "prophet",
            "forecast_values": [float(v) for v in forecast_values],
            "forecast_lower": [float(v) for v in forecast_lower],
            "forecast_upper": [float(v) for v in forecast_upper],
            "forecast_mean": float(np.mean(forecast_values))
        }

        # Calculate accuracy if test data available
        accuracy = None
        if ts_test is not None and len(ts_test) > 0:
            # The forecast starts right after the training data, so its first
            # values line up with the hold-out set. Compare only the overlap:
            # a hold-out longer than the horizon must not discard the forecast.
            n_compare = min(len(ts_test), len(forecast_values))
            if n_compare > 0:
                accuracy = _holdout_accuracy(
                    ts_test[:n_compare],
                    np.asarray(forecast_values[:n_compare], dtype=np.float64)
                )

        return forecast_dict, accuracy
    except ImportError:
        # Prophet not installed, return None
        return None, None
    except Exception as e:
        logging.exception("Error in forecast_prophet, using fallback")
        # Fallback if Prophet fails: flat forecast at the training mean with a
        # +/- 1.96 standard deviation band.
        mean_value = float(np.mean(ts_train))
        std_dev = float(np.std(ts_train))
        confidence_multiplier = 1.96
        forecast_values = [mean_value] * forecast_horizon
        forecast_lower = [mean_value - confidence_multiplier * std_dev] * forecast_horizon
        forecast_upper = [mean_value + confidence_multiplier * std_dev] * forecast_horizon

        forecast_dict = {
            "method": "prophet",
            "forecast_values": forecast_values,
            "forecast_lower": forecast_lower,
            "forecast_upper": forecast_upper,
            "forecast_mean": mean_value,
            "error": str(e)
        }
        return forecast_dict, None


def determine_best_forecast_method(forecast_accuracy: Dict) -> Optional[Dict]:
    """
    Determines the best forecasting method based on accuracy metrics.

    All methods are ranked on one shared metric: MAPE when every method
    reports a usable MAPE (not None and below 10000), otherwise MAE. Scoring
    some methods by MAPE and others by MAE would compare percentages with
    absolute units and could select a method that is worse on both.

    Parameters
    ----------
    forecast_accuracy : Dict
        Maps method name to its accuracy dict (``mae``, ``rmse``, ``mape``).
        Entries that are None or empty are ignored.

    Returns
    -------
    Optional[Dict]
        ``{"method": name, "accuracy": accuracy_dict}`` for the lowest score
        (the first method wins ties), or None when no method has a score.
    """
    if not forecast_accuracy:
        return None

    # Methods without any recorded accuracy cannot be ranked.
    candidates = {name: acc for name, acc in forecast_accuracy.items() if acc}
    if not candidates:
        return None

    max_plausible_mape = 10000  # Reasonable MAPE threshold

    def has_usable_mape(acc: Dict) -> bool:
        mape = acc.get("mape")
        return mape is not None and mape < max_plausible_mape

    metric = "mape" if all(has_usable_mape(acc) for acc in candidates.values()) else "mae"

    best_method_name = None
    best_score = float('inf')
    for method_name, acc in candidates.items():
        score = acc.get(metric)
        # Strict '<' keeps the earliest method on ties and never selects NaN.
        if score is not None and score < best_score:
            best_score = score
            best_method_name = method_name

    if best_method_name is None:
        return None

    return {
        "method": best_method_name,
        "accuracy": forecast_accuracy[best_method_name]
    }


def generate_forecast_insights(forecasts: Dict, forecast_accuracy: Dict, best_method: Optional[Dict], time_series_summary: Dict) -> Dict:
    """
    Generates insights about the forecasts.

    Parameters
    ----------
    forecasts : Dict
        Maps method name to its forecast dict; only ``"forecast_values"`` is
        read. Entries that are not dicts or have no values are ignored.
    forecast_accuracy : Dict
        Not used by this function.
    best_method : Optional[Dict]
        Best-method dict; its ``"method"`` entry is copied into the result.
    time_series_summary : Dict
        Its ``"mean"`` entry (default 0) is the baseline the forecasts are
        compared against.

    Returns
    -------
    Dict
        ``n_methods``, ``best_method``, ``forecast_direction``
        ("increasing", "decreasing", "stable", or None when no method
        produced a value) and ``forecast_confidence`` ("low", "moderate"
        or "high").
    """
    insights = {
        "n_methods": len(forecasts),
        "best_method": best_method.get("method") if best_method else None,
        "forecast_direction": None,
        "forecast_confidence": "low"
    }

    # First forecast step of every method that actually produced values.
    first_steps = []
    for method_forecast in forecasts.values():
        if not isinstance(method_forecast, dict):
            continue
        values = method_forecast.get("forecast_values")
        if values is not None and len(values) > 0:
            first_steps.append(values[0])

    if not first_steps:
        # Nothing to average: leave the direction undetermined.
        return insights

    baseline = time_series_summary.get("mean") or 0
    avg_forecast = np.mean(first_steps)

    # +/-5% band around the baseline. Sorting the bounds keeps the band
    # well-formed when the baseline is negative.
    lower, upper = sorted((baseline * 0.95, baseline * 1.05))
    if avg_forecast > upper:
        insights["forecast_direction"] = "increasing"
    elif avg_forecast < lower:
        insights["forecast_direction"] = "decreasing"
    else:
        insights["forecast_direction"] = "stable"

    # Confidence based on agreement between methods; a single usable forecast
    # has zero spread by construction and says nothing about agreement.
    if len(first_steps) > 1:
        forecast_std = np.std(first_steps)
        scale = abs(baseline)
        if forecast_std < scale * 0.1:
            insights["forecast_confidence"] = "high"
        elif forecast_std < scale * 0.2:
            insights["forecast_confidence"] = "moderate"

    return insights


def generate_time_series_insights_summary(
    time_series_summary: Dict,
    trend_analysis: Dict,
    seasonality_analysis: Dict,
    forecasts: Dict,
    forecast_accuracy: Dict,
    best_method: Optional[Dict]
) -> str:
    """
    Generates textual summary of time series forecasting insights.

    Parameters
    ----------
    time_series_summary : Dict
        Must contain ``n_observations``, ``start_date`` and ``end_date``.
    trend_analysis : Dict
        ``trend`` (and optionally ``strength``) are reported when ``trend`` is truthy.
    seasonality_analysis : Dict
        ``has_seasonality`` and ``period`` drive the seasonality sentence.
    forecasts : Dict
        Only its size is reported.
    forecast_accuracy : Dict
        Not used by this function.
    best_method : Optional[Dict]
        ``{"method": ..., "accuracy": ...}``; the accuracy may be None or
        lack metrics, in which case only the method name is reported.

    Returns
    -------
    str
        The summary sentences joined by single spaces.

    Raises
    ------
    KeyError
        If ``time_series_summary`` lacks one of its required keys, or
        ``best_method`` is given without a ``"method"`` entry.
    """
    summary_parts = []

    # Time series overview
    summary_parts.append(
        f"Time series analysis of {time_series_summary['n_observations']} observations "
        f"from {time_series_summary['start_date']} to {time_series_summary['end_date']}."
    )

    # Trend
    if trend_analysis.get("trend"):
        trend = trend_analysis["trend"]
        strength = trend_analysis.get("strength", "unknown")
        summary_parts.append(f"Trend: {strength} {trend} trend detected.")

    # Seasonality
    if seasonality_analysis.get("has_seasonality"):
        period = seasonality_analysis.get("period", "unknown")
        summary_parts.append(f"Seasonality: detected with period {period}.")
    else:
        summary_parts.append("Seasonality: no significant seasonality detected.")

    # Forecasts
    if forecasts:
        summary_parts.append(f"Forecasts generated using {len(forecasts)} method(s).")

    # Best method
    if best_method:
        method_name = best_method["method"]
        # The accuracy entry may be None when no hold-out data was scored.
        accuracy = best_method.get("accuracy") or {}
        mape = accuracy.get("mape")
        mae = accuracy.get("mae")

        # Prefer MAPE, fall back to MAE, and omit the metric when neither is
        # a number that can be formatted.
        if mape is not None:
            metric_note = f" (MAPE: {mape:.2f}%)"
        elif mae is not None and mae != "N/A":
            metric_note = f" (MAE: {mae:.2f})"
        else:
            metric_note = ""
        summary_parts.append(f"Best forecasting method: {method_name}{metric_note}.")

    return " ".join(summary_parts)

