# -*- coding: utf-8 -*-

import pickle as pkl

import dataiku
from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split

from clv_forecast.auto_ml.custom_model.custom_clustering_model import (
    QuantileClusteringModel,
)
from clv_forecast.utils.clustering import apply_clustering
from dku_utils.projects.project_commons import get_current_project_and_variables
import numpy as np
project, variables = get_current_project_and_variables()

# Read recipe inputs: a random sample of "ml_data_prepared", capped at the
# "training_size_app" project variable.
ml_data_filtered = dataiku.Dataset("ml_data_prepared")
max_ml_data_size = int(variables["standard"]["training_size_app"])
ml_data_filtered_df = ml_data_filtered.get_dataframe(
    limit=max_ml_data_size,
    sampling="random",
)


def _as_bool(value):
    """Interpret a project-variable value as a boolean.

    Plain ``bool()`` is wrong for string-valued variables (``bool("false")``
    is ``True``). Strings spelling a negative -- "", "false", "0", "no",
    "n", "off" (case-insensitive, surrounding whitespace ignored) -- map to
    ``False``; every other value falls back to ``bool(value)``.
    """
    if isinstance(value, str):
        return value.strip().lower() not in {"", "false", "0", "no", "n", "off"}
    return bool(value)


n_cluster = int(variables["standard"]["number_value_cluster_app"])
notrain_flag = _as_bool(variables["standard"]["notrain_app"])
folder_handle = dataiku.Folder("cluster_model")
clustering_method = variables["standard"]["value_clustering_method_app"]
# NOTE(review): passed through unparsed; if this variable can be a string
# such as "false", it may need the same _as_bool treatment -- confirm what
# QuantileClusteringModel expects.
extend_low = variables["standard"]["extend_low_app"]

# With "manual" clustering or the no-train flag set, no model is fitted or
# written here (presumably apply_clustering then relies on an existing
# model/thresholds -- confirm). Otherwise fit a fresh model on
# average_monthly_value and pickle it to the "cluster_model" folder.
if clustering_method != "manual" and not notrain_flag:
    if clustering_method == "quantiles":
        model = QuantileClusteringModel(
            target_column="average_monthly_value",
            n_clusters=n_cluster,
            column_labels=["average_monthly_value"],
            extend_low=extend_low,
        )
    elif clustering_method == "kmeans":
        # Seeded so reruns on the same sample produce the same clusters,
        # like the seeded train/test split later in this recipe.
        model = KMeans(n_clusters=n_cluster, random_state=42)
    else:
        raise KeyError(
            f'{clustering_method} is an unknown clustering method; please use "quantiles" or "kmeans"'
        )
    # Fit on a single feature column: reshape to (n_samples, 1).
    model.fit(ml_data_filtered_df["average_monthly_value"].to_numpy().reshape(-1, 1))
    with folder_handle.get_writer("model.pkl") as w:
        pkl.dump(model, w)
            
# Floor future_clv at 0 and cap it at its 99th percentile to limit the
# influence of extreme values. Series.quantile skips NaN (np.percentile would
# return NaN, and clipping against a NaN bound turns every value into NaN);
# on an empty column the quantile is NaN, which Series.clip ignores as a bound.
future_clv = ml_data_filtered_df["future_clv"]
ml_data_filtered_df["future_clv"] = future_clv.clip(
    lower=0, upper=future_clv.quantile(0.99)
)


# Label each row with a value cluster for the current and the future period.
for value_column, cluster_column in (
    ("average_monthly_value", "current_clv_cluster"),
    ("future_average_monthly_value", "future_clv_cluster"),
):
    ml_data_filtered_df[cluster_column] = apply_clustering(
        ml_data_filtered_df[value_column]
    )

# Missing recency is replaced by the sentinel -1 so the column can be cast to int.
recency = ml_data_filtered_df["lifetime_recency"].fillna(-1)
ml_data_filtered_df["lifetime_recency"] = recency.astype(int)

# Seeded split, stratified on the current value cluster so train and test
# keep the same cluster proportions.
train_test_ratio = float(variables["standard"]["train_test_ratio_app"])
cluster_labels = ml_data_filtered_df["current_clv_cluster"].values
train, test = train_test_split(
    ml_data_filtered_df,
    stratify=cluster_labels,
    random_state=42,
    train_size=train_test_ratio,
)

# Write recipe outputs: the train split to "ml_data_train" and the test
# split to "ml_data_test", each with its dataframe schema.
ml_data_train = dataiku.Dataset("ml_data_train")
ml_data_test = dataiku.Dataset("ml_data_test")
for output_dataset, output_df in ((ml_data_train, train), (ml_data_test, test)):
    output_dataset.write_with_schema(output_df)
