import dataiku
import pandas as pd
import pickle as pkl

from clv_forecast.utils.generic import reformat_group
from dku_utils.projects.project_commons import get_current_project_and_variables


def assign_manual_cluster(value, cluster_df: pd.DataFrame) -> int:
    """Return the manual cluster number for ``value``.

    Clusters are delimited by the thresholds in ``cluster_df["threshold_value"]``.
    A value below every threshold is in cluster 0. Otherwise it is in cluster
    ``k``, where ``k`` is the number of thresholds less than or equal to it.
    A value equal to a threshold therefore falls into the upper cluster.

    The result depends only on the threshold values, not on the row order or
    the index labels of ``cluster_df``. A frame sorted with ``sort_values``
    (which keeps its original labels) therefore yields the same clusters as an
    unsorted one.

    :param value: scalar to classify; presumably numeric and comparable with
        the thresholds (TODO confirm against the input column dtype).
    :param cluster_df: frame with a ``threshold_value`` column.
    :return: cluster number in ``[0, number of non-null thresholds]``. With
        no thresholds at all, every value maps to 0.
    :raises ValueError: if ``value`` is missing (NaN/None/NaT).
    """
    if pd.isna(value):
        # Comparing a missing value against the thresholds is meaningless;
        # fail loudly instead of silently picking a cluster.
        raise ValueError("Cannot assign a manual cluster to a missing value")
    # The count of thresholds <= value equals the 1-based rank of the last such
    # threshold in ascending order. NaN thresholds compare False and are
    # ignored, mirroring how Series.min() skips them.
    return int((cluster_df["threshold_value"] <= value).sum())


def apply_clustering(input_column: pd.Series):
    """Assign a reformatted cluster group to every value of ``input_column``.

    The method is read from the project variable
    ``variables["standard"]["value_clustering_method_app"]``:

    * ``"manual"``: thresholds are read from the ``customer_clusters``
      dataset, sorted ascending on ``threshold_value``, and each value is
      mapped with ``assign_manual_cluster``.
    * any other value: a pickled model is loaded from ``model.pkl`` in the
      ``cluster_model`` managed folder. Its ``predict`` method is called on
      the values reshaped to a single-column 2-D array.

    Every raw cluster label is then passed through ``reformat_group``.

    :param input_column: values to cluster.
    :return: list with one reformatted group per input value, in input order.
    """
    _, variables = get_current_project_and_variables()
    clustering_method = variables["standard"]["value_clustering_method_app"]
    if clustering_method == "manual":
        customer_clusters_df = (
            dataiku.Dataset("customer_clusters")
            .get_dataframe()
            .sort_values(by=["threshold_value"])
            # sort_values keeps the original row labels. Reset them so that row
            # labels equal each threshold's rank (0..k-1), which keeps cluster
            # ids independent of the dataset's storage order.
            .reset_index(drop=True)
        )
        raw_clusters = input_column.apply(
            lambda value: assign_manual_cluster(value, customer_clusters_df)
        )
    else:
        folder_handle = dataiku.Folder("cluster_model")
        # NOTE(review): pickle.load can execute arbitrary code. This assumes the
        # cluster_model folder only holds trusted, project-produced models.
        with folder_handle.get_download_stream("model.pkl") as reader:
            model = pkl.load(reader)
        # The model receives a single feature column (n_samples, 1). This
        # presumably follows the scikit-learn predict convention; confirm.
        raw_clusters = model.predict(input_column.to_numpy().reshape(-1, 1))
    return [reformat_group(cluster) for cluster in raw_clusters]
