import dataiku

from clv_forecast.constants import (
    CLUSTER_CLASSIFICATION_MODEL_ANALYSIS_ID,
    CLUSTER_CLASSIFICATION_MODEL_ML_TASK_ID,
    CLUSTER_CLASSIFICATION_TRAINING_RECIPE_NAME,
    CLUSTER_CLASSIFICATION_DEPLOYED_MODEL_ID,
)
from clv_forecast.constants import (
    LIFETIME_MODEL_ANALYSIS_ID,
    LIFETIME_MODEL_ML_TASK_ID,
    LIFETIME_TRAINING_RECIPE_NAME,
    LIFETIME_DEPLOYED_MODEL_ID,
)
from clv_forecast.constants import (
    VALUE_REGRESSION_MODEL_ANALYSIS_ID,
    VALUE_REGRESSION_MODEL_ML_TASK_ID,
    VALUE_REGRESSION_TRAINING_RECIPE_NAME,
    VALUE_REGRESSION_DEPLOYED_MODEL_ID,
)
from clv_forecast.utils.feature_processing import update_processing_features
from dku_utils.projects.project_commons import get_current_project_and_variables
from dku_utils.projects.visual_ml.visual_ml_commons import (
    get_ml_task_and_settings,
    train_models_then_deploy_best_from_last_session,
)
from dku_utils.projects.visual_ml.visual_ml_commons import (
    remove_unavailable_features_from_ml_task_features_handling,
    reject_features_from_ml_task_settings,
)


def train_lifetime_model():
    """Retrain the lifetime ML task and deploy its best model to the flow.

    Steps performed:
      1. Increment ``splitParams.instanceIdRefresher`` and save the settings
         (presumably to make Dataiku regenerate the train/test split -- confirm).
      2. Remove feature-handling entries for columns that are no longer
         available in the ``ml_data_train`` dataset.
      3. Reject every column of ``ml_data_train`` that the task already
         handles but that is not one of the lifetime columns.
      4. Train, then deploy the best model of the last session (lowest MAPE)
         to the flow recipe / saved model configured in the constants.
    """
    project, _ = get_current_project_and_variables()

    ml_task, ml_task_settings = get_ml_task_and_settings(
        project, LIFETIME_MODEL_ANALYSIS_ID, LIFETIME_MODEL_ML_TASK_ID,
    )

    ml_task_settings.get_raw()["splitParams"]["instanceIdRefresher"] += 1
    ml_task_settings.save()

    remove_unavailable_features_from_ml_task_features_handling(
        project, ml_task_settings, "ml_data_train"
    )

    # Only columns the ML task already has a handling entry for are rejected;
    # a set makes the per-column membership test O(1).
    previously_handled_features = set(
        ml_task_settings.get_raw()["preprocessing"]["per_feature"]
    )

    train_dataset = dataiku.Dataset("ml_data_train")
    schema = train_dataset.read_schema()

    # Columns left untouched; "future_clv" is presumably the target -- confirm.
    lifetime_columns = {
        "customer_age",
        "lifetime_recency",
        "lifetime_frequency",
        "average_transaction_value",
        "future_clv",
    }

    # Preserve schema order so the rejection list is deterministic.
    to_remove_columns = [
        column["name"]
        for column in schema
        if column["name"] not in lifetime_columns
        and column["name"] in previously_handled_features
    ]
    reject_features_from_ml_task_settings(ml_task_settings, to_remove_columns)

    train_models_then_deploy_best_from_last_session(
        ml_task,
        metric_name="MAPE",
        bool_greater_metric_is_better=False,
        bool_compute_sub_population_analysis=True,
        sub_population_analysis_variables=[],
        sub_population_analysis_test_set_fraction_to_sample=1.0,
        bool_compute_partial_dependencies=False,
        partial_dependencies_variables=[],
        partial_dependencies_test_set_fraction_to_sample=1.0,
        flow_training_recipe_name=LIFETIME_TRAINING_RECIPE_NAME,
        flow_saved_model_id=LIFETIME_DEPLOYED_MODEL_ID,
    )


def train_regression_model():
    """Retrain the value regression ML task and deploy its best model.

    Steps performed:
      1. Increment ``splitParams.instanceIdRefresher`` and save the settings
         (presumably to make Dataiku regenerate the train/test split -- confirm).
      2. Update the task's feature processing from the
         ``ml_data_train_with_lifetime_prepared`` dataset.
      3. Collect the analysis columns. These are every feature whose role is
         ``INPUT``, plus, when ``leverage_customer_metadata_app`` is enabled,
         every non-``customer_id`` column of the customer metadata dataset.
         Duplicates are removed.
      4. Train, then deploy the best model of the last session (lowest MAPE),
         computing sub-population analysis and partial dependencies on the
         analysis columns.
    """
    project, variables = get_current_project_and_variables()

    ml_task, ml_task_settings = get_ml_task_and_settings(
        project, VALUE_REGRESSION_MODEL_ANALYSIS_ID, VALUE_REGRESSION_MODEL_ML_TASK_ID,
    )

    ml_task_settings.get_raw()["splitParams"]["instanceIdRefresher"] += 1
    ml_task_settings.save()

    update_processing_features(
        project,
        variables,
        ml_task,
        ml_task_settings,
        train_dataset_name="ml_data_train_with_lifetime_prepared",
    )

    # Re-fetch the task so the settings read below reflect what
    # update_processing_features changed (presumably it saves them -- confirm).
    ml_task, ml_task_settings = get_ml_task_and_settings(
        project, VALUE_REGRESSION_MODEL_ANALYSIS_ID, VALUE_REGRESSION_MODEL_ML_TASK_ID,
    )
    per_feature = ml_task_settings.get_raw()["preprocessing"]["per_feature"]

    columns_analytics = [
        feature
        for feature, handling in per_feature.items()
        if handling["role"] == "INPUT"
    ]

    if variables["standard"]["leverage_customer_metadata_app"]:
        metadata_dataset = dataiku.Dataset(
            variables["standard"]["customer_metadata_dataset_app"]
        )
        columns_analytics += [
            column["name"]
            for column in metadata_dataset.read_schema()
            if column["name"] != "customer_id"
        ]

    # Metadata columns may also be INPUT features of the task. Drop repeats,
    # keeping first-seen order, so no column is analysed twice.
    columns_analytics = list(dict.fromkeys(columns_analytics))

    train_models_then_deploy_best_from_last_session(
        ml_task,
        metric_name="MAPE",
        bool_greater_metric_is_better=False,
        bool_compute_sub_population_analysis=True,
        sub_population_analysis_variables=columns_analytics,
        sub_population_analysis_test_set_fraction_to_sample=1.0,
        bool_compute_partial_dependencies=True,
        partial_dependencies_variables=columns_analytics,
        partial_dependencies_test_set_fraction_to_sample=1.0,
        flow_training_recipe_name=VALUE_REGRESSION_TRAINING_RECIPE_NAME,
        flow_saved_model_id=VALUE_REGRESSION_DEPLOYED_MODEL_ID,
    )


def train_classification_model():
    """Retrain the cluster classification ML task and deploy its best model.

    Steps performed:
      1. Increment ``splitParams.instanceIdRefresher`` and save the settings
         (presumably to make Dataiku regenerate the train/test split -- confirm).
      2. Run ``ml_task.guess()``, keep only the freshly guessed
         ``target_remapping``, and restore every other setting from before the
         guess. This adapts the task to a new number of classes.
      3. Update the task's feature processing from the
         ``ml_data_train_with_lifetime_prepared`` dataset.
      4. Collect the analysis columns. These are every feature whose role is
         ``INPUT``, plus, when ``leverage_customer_metadata_app`` is enabled,
         every non-``customer_id`` column of the customer metadata dataset.
         Duplicates are removed.
      5. Train, then deploy the best model of the last session (highest F1),
         computing partial dependencies on the analysis columns.
    """
    project, variables = get_current_project_and_variables()

    ml_task, ml_task_settings = get_ml_task_and_settings(
        project,
        CLUSTER_CLASSIFICATION_MODEL_ANALYSIS_ID,
        CLUSTER_CLASSIFICATION_MODEL_ML_TASK_ID,
    )
    ml_task_settings.get_raw()["splitParams"]["instanceIdRefresher"] += 1
    ml_task_settings.save()

    # Capture the pre-guess settings (including the bumped refresher). The
    # settings object is re-fetched after guess(), so this dict is not
    # overwritten by the guess.
    previous_settings = ml_task_settings.get_raw()

    ml_task.guess()
    ml_task, ml_task_settings = get_ml_task_and_settings(
        project,
        CLUSTER_CLASSIFICATION_MODEL_ANALYSIS_ID,
        CLUSTER_CLASSIFICATION_MODEL_ML_TASK_ID,
    )
    # Keep only the guessed target remapping; restore everything else.
    previous_settings["preprocessing"]["target_remapping"] = ml_task_settings.get_raw()[
        "preprocessing"
    ]["target_remapping"]
    ml_task_settings.mltask_settings = previous_settings
    ml_task_settings.save()

    update_processing_features(
        project,
        variables,
        ml_task,
        ml_task_settings,
        train_dataset_name="ml_data_train_with_lifetime_prepared",
    )

    # Re-fetch the task so the settings read below reflect what
    # update_processing_features changed (presumably it saves them -- confirm).
    ml_task, ml_task_settings = get_ml_task_and_settings(
        project,
        CLUSTER_CLASSIFICATION_MODEL_ANALYSIS_ID,
        CLUSTER_CLASSIFICATION_MODEL_ML_TASK_ID,
    )
    per_feature = ml_task_settings.get_raw()["preprocessing"]["per_feature"]

    columns_analytics = [
        feature
        for feature, handling in per_feature.items()
        if handling["role"] == "INPUT"
    ]

    if variables["standard"]["leverage_customer_metadata_app"]:
        metadata_dataset = dataiku.Dataset(
            variables["standard"]["customer_metadata_dataset_app"]
        )
        columns_analytics += [
            column["name"]
            for column in metadata_dataset.read_schema()
            if column["name"] != "customer_id"
        ]

    # Metadata columns may also be INPUT features of the task. Drop repeats,
    # keeping first-seen order, so no column is analysed twice.
    columns_analytics = list(dict.fromkeys(columns_analytics))

    train_models_then_deploy_best_from_last_session(
        ml_task,
        metric_name="F1",
        bool_greater_metric_is_better=True,
        bool_compute_sub_population_analysis=False,
        sub_population_analysis_variables=columns_analytics,
        sub_population_analysis_test_set_fraction_to_sample=1.0,
        bool_compute_partial_dependencies=True,
        partial_dependencies_variables=columns_analytics,
        partial_dependencies_test_set_fraction_to_sample=1.0,
        flow_training_recipe_name=CLUSTER_CLASSIFICATION_TRAINING_RECIPE_NAME,
        flow_saved_model_id=CLUSTER_CLASSIFICATION_DEPLOYED_MODEL_ID,
    )
