import json
import logging
import pandas as pd
import numpy as np
import sklearn

from sklearn.mixture import GaussianMixture

from dataiku.base.utils import package_is_at_least
from dataiku.doctor.diagnostics import diagnostics
from dataiku.doctor.diagnostics.diagnostics import DiagnosticType
from dataiku.doctor.clustering.common import prepare_multiframe
from sklearn.cluster import MiniBatchKMeans, AgglomerativeClustering, DBSCAN, SpectralClustering
if package_is_at_least(sklearn, "1.3"):
    from sklearn.cluster import HDBSCAN

from dataiku.doctor.clustering.anomaly_detection import DkuIsolationForest
from dataiku.doctor.clustering.two_step_clustering import TwoStepClustering
from dataiku.doctor.prediction.common import get_initial_intrinsic_perf_data
from dataiku.doctor.utils.estimator import set_column_labels
from dataiku.doctor.utils.skcompat import get_kmeans_estimator, get_kmeans_n_init_value, get_minibatch_n_init_value

logger = logging.getLogger(__name__)


def scikit_model(modeling_params):
    """Build a clustering estimator from user-written Python code.

    The source stored under ``modeling_params['scikit_clf']`` is executed in a
    fresh namespace that is pre-seeded with ``n_clusters`` (the value of the
    optional ``"k"`` entry, ``None`` when absent). The code must bind the
    estimator to a variable named ``clf``.

    NOTE: this runs arbitrary Python via ``exec``; only trusted, user-authored
    custom-model code should ever reach this function.

    :param modeling_params: dict holding at least ``"scikit_clf"`` (source code)
    :return: the object bound to ``clf`` by the executed code
    :raises Exception: if the code leaves ``clf`` undefined or sets it to None
    """
    namespace = dict(n_clusters=modeling_params.get("k", None))
    exec(modeling_params['scikit_clf'], namespace)

    estimator = namespace.get("clf")
    if estimator is not None:
        return estimator
    raise Exception("No variable 'clf' defined in Custom Python model")


def clustering_model_from_params(modeling_params, rows=0):
    """Instantiate the (unfitted) clustering estimator described by ``modeling_params``.

    :param modeling_params: dict of modeling parameters. ``"algorithm"`` selects
        the estimator; the other keys read depend on the chosen algorithm.
    :param rows: number of training rows, used to turn the DBSCAN / HDBSCAN
        ratio parameters into absolute sample counts.
    :return: an unfitted estimator instance
    :raises ValueError: if ``"algorithm"`` is not one of the supported values
    :raises Exception: for HDBSCAN when scikit-learn is older than 1.3, or when
        custom ``SCIKIT_MODEL`` code does not define ``clf``
    """
    algorithm = modeling_params['algorithm']
    seed = modeling_params.get("seed")  # None means random
    n_jobs = modeling_params.get("n_jobs", 2)
    # "k" may be absent or explicitly null for algorithms that do not use it
    # (see scikit_model, which also tolerates a missing "k").
    k = int(modeling_params.get("k") or 0)

    if algorithm == "SCIKIT_MODEL":
        return scikit_model(modeling_params)
    elif algorithm == 'KMEANS':
        # %s (not %d) so a null n_jobs cannot break logging
        logger.info("KMEANS k=%s n_jobs=%s", k, n_jobs)
        return get_kmeans_estimator(n_clusters=k, n_jobs=n_jobs, random_state=seed,
                                    n_init=get_kmeans_n_init_value(modeling_params))
    elif algorithm == 'MiniBatchKMeans':
        return MiniBatchKMeans(n_clusters=k, random_state=seed,
                               n_init=get_minibatch_n_init_value(modeling_params))
    elif algorithm == 'SPECTRAL':
        return SpectralClustering(n_clusters=k,
                                  affinity=modeling_params["affinity"],
                                  coef0=modeling_params.get("coef0"),
                                  gamma=modeling_params.get("gamma"),
                                  random_state=seed)
    elif algorithm == 'WARD':
        return AgglomerativeClustering(n_clusters=k)
    elif algorithm == 'DBSCAN':
        # DBSCAN requires min_samples >= 1; a small ratio (or rows == 0) would
        # otherwise yield 0. min_samples=1 is equivalent anyway, since every
        # point is its own neighbour.
        min_samples = max(1, int(float(modeling_params["min_sample_ratio"]) * rows))
        return DBSCAN(eps=float(modeling_params["epsilon"]), min_samples=min_samples)
    elif algorithm == 'HDBSCAN':
        if not package_is_at_least(sklearn, "1.3"):
            raise Exception("HDBScan algorithm requires a version of scikit-learn greater or equal to 1.3")
        # Compute the size once, with the same float conversion for both the
        # threshold check and the estimator.
        min_cluster_size = int(float(modeling_params["min_cluster_size_ratio"]) * rows)
        if min_cluster_size < 2:
            diagnostics.add_or_update(DiagnosticType.ML_DIAGNOSTICS_MODELING_PARAMETERS, "The minimum number of samples per cluster was computed to be"
                                                                                         " less than 2. The value was set to 2. As a result all values of "
                                                                                         "min_cluster_size_ratio that are below this limit will yield the "
                                                                                         "same clustering")
            return HDBSCAN(min_cluster_size=2)
        return HDBSCAN(min_cluster_size=min_cluster_size)
    elif algorithm == 'GMM':
        return GaussianMixture(n_components=k, random_state=seed,
                               max_iter=int(modeling_params["max_iterations"]))
    elif algorithm == 'PY_TWO_STEP':
        return TwoStepClustering(k, int(modeling_params["ts_kmeans_k"]), int(modeling_params["max_iterations"]), seed)
    elif algorithm == 'PY_ISOLATION_FOREST':
        par = modeling_params["isolation_forest"]
        contamination = "auto" if par["use_auto_contamination"] else par["contamination"]
        return DkuIsolationForest(n_estimators=par["n_estimators"], max_samples=par["max_samples"],
                                  max_features=par["max_features"], contamination=contamination,
                                  bootstrap=par["bootstrap"], max_anomalies=par["max_anomalies"],
                                  random_state=par["seed"], n_jobs=par["n_jobs"])

    # Previously this fell through and returned None, which only failed later
    # (and more obscurely) when the caller tried to use the estimator.
    raise ValueError("Unknown clustering algorithm: %s" % algorithm)


class ClusteringModelInspector(object):
    """Reports the parameters actually used by a clustering model.

    No parameter resolution is done for clustering models: the "resolved"
    parameters are simply a copy of the modeling parameters.
    """

    def __init__(self, modeling_params, clf):
        # modeling_params: parameters the model was built from
        # clf: the estimator; stored but not read by get_actual_params
        self.modeling_params = modeling_params
        self.clf = clf

    def get_actual_params(self):
        """Return ``{"resolved": <copy of modeling_params>}``.

        The copy is made through a JSON round-trip, so it shares no mutable
        state with ``self.modeling_params`` and holds only JSON types.

        :raises KeyError: if ``modeling_params`` has no ``"algorithm"`` entry
        """
        params = self.modeling_params
        resolved = json.loads(json.dumps(params))
        algo_name = params['algorithm']

        logger.info("Clustering model inspector algo=%s" % algo_name)
        # Nothing to resolve for clustering models at the moment.
        logger.info("End of get_actu_params: now %s" % resolved)

        return dict(resolved=resolved)


def clustering_predict(modeling_params, clusterer, transformed_data):
    """Assign clusters to the "TRAIN" frame of ``transformed_data`` using an already-fitted clusterer.

    :param modeling_params: modeling parameters, forwarded to ``prepare_multiframe``
    :param clusterer: fitted estimator; a ``DkuIsolationForest`` additionally
        yields anomaly scores
    :param transformed_data: dict-like holding the preprocessed multiframe
        under the ``"TRAIN"`` key
    :return: tuple ``(cluster_labels, anomaly_scores)``. ``cluster_labels`` is
        whatever the estimator's predict method returns; ``anomaly_scores`` is
        None unless ``clusterer`` is a ``DkuIsolationForest``.
    """
    train = transformed_data["TRAIN"]
    train_np, _ = prepare_multiframe(train, modeling_params)

    train_df = train.as_dataframe()
    # Log the first row for debugging. The guard matches clustering_fit:
    # .iloc[0] raises IndexError when every row was dropped.
    if train_df.shape[0] > 0:
        for col in train_df:
            logger.info("F %s=%s", col, train_df[col].iloc[0])

    if isinstance(clusterer, DkuIsolationForest):
        cluster_labels, anomaly_scores = clusterer.predict_with_anomaly_score(train_np)
    else:
        cluster_labels = clusterer.predict(train_np)
        anomaly_scores = None
    return cluster_labels, anomaly_scores


def clustering_fit(modeling_params, transformed_train):
    """Build the clustering estimator, fit it on the "TRAIN" frame and label the training rows.

    :param modeling_params: modeling parameters (see ``clustering_model_from_params``)
    :param transformed_train: dict-like holding the preprocessed multiframe
        under the ``"TRAIN"`` key
    :return: tuple ``(clf, actual_params, cluster_labels, anomaly_scores,
        initial_intrinsic_perf_data)`` where:

        - ``actual_params`` is the ``{"resolved": ...}`` dict produced by
          ``ClusteringModelInspector``;
        - ``cluster_labels`` is an empty array when preprocessing dropped every
          row (the model is then left unfitted);
        - ``anomaly_scores`` is None unless ``clf`` is a ``DkuIsolationForest``.
    """
    train = transformed_train["TRAIN"]

    clf = clustering_model_from_params(modeling_params, len(train.index))
    # feed the column labels to the model
    set_column_labels(clf, train.columns())

    train_np, is_sparse = prepare_multiframe(train, modeling_params)
    initial_intrinsic_perf_data = get_initial_intrinsic_perf_data(train_np, is_sparse)

    train_df = train.as_dataframe()

    anomaly_scores = None
    if train_df.shape[0] > 0:
        for col in train_df:
            logger.info("FP %s=%s", col, train_df[col].iloc[0])
        if isinstance(clf, DkuIsolationForest):
            # only for anomaly detection
            cluster_labels, anomaly_scores = clf.fit_predict_with_anomaly_score(train_np)
        elif hasattr(clf, "fit_predict"):
            # hasattr (unlike `'fit_predict' in dir(clf)`) honours scikit-learn's
            # `available_if`: e.g. a custom Pipeline lists fit_predict in dir()
            # but only provides it when its final step does.
            cluster_labels = clf.fit_predict(train_np)
        else:
            clf.fit(train_np)
            cluster_labels = clf.predict(train_np)
    else:
        logger.warning("Cannot fit clustering model: all rows have been dropped by preprocessing")
        cluster_labels = np.empty((0,))

    actual_params = ClusteringModelInspector(modeling_params, clf).get_actual_params()
    return (clf, actual_params, cluster_labels, anomaly_scores, initial_intrinsic_perf_data)
