import logging

from dataiku.doctor.prediction.classification_fit import classification_fit
from dataiku.doctor.prediction.classification_scoring import binary_classification_scorer_with_valid
from dataiku.doctor.prediction.classification_scoring import multiclass_scorer_with_valid
from dataiku.doctor.prediction.prediction_interval_model import train_prediction_interval_or_none
from dataiku.doctor.prediction.regression_fit import regression_fit_single
from dataiku.doctor.prediction.regression_scoring import regression_scorer_with_valid
from dataiku.doctor.prediction.scorable_model import ScorableModel
from dataiku.doctor.utils import doctor_constants

logger = logging.getLogger(__name__)


def minimal_train_get_train_test_score(transformed_train,
                                       transformed_test,
                                       train_df_index,
                                       test_df_index,
                                       core_params,
                                       modeling_params,
                                       model_folder_context,
                                       target_map,
                                       preprocessing_params,
                                       ml_overrides_params
                                       ):
    """
    Fit a model on the transformed train set, then score it on both test and train sets.

    The fit is "minimal": it is run with ``model_folder_context=None`` and no
    grid-search callback. The scorers, however, do receive ``model_folder_context``.

    :param transformed_train: transformed train data, used for fitting and for train scoring
    :param transformed_test: transformed test data, used for test scoring (and, for
        regression, passed to ``train_prediction_interval_or_none``)
    :param train_df_index: row index handed to the train scorer
    :param test_df_index: row index handed to the test scorer
    :param core_params: reads "prediction_type" and "taskType" (both required), plus the
        optional "weight" -> "weightMethod" and "calibration" -> "calibrationMethod" entries
    :param modeling_params: passed to the fit and scorers; "algorithm" is read to build the model
    :param model_folder_context: folder context handed to the scorers
    :param target_map: target mapping, used by the classification fit and scorers only
    :param preprocessing_params: passed to ``ScorableModel.build`` in the classification branch only
    :param ml_overrides_params: ML overrides passed to ``ScorableModel.build``
    :return: dict with "train_metrics" and "test_metrics", each the ``perf_data`` of its scorer
    :raises ValueError: if the prediction type is not binary, multiclass or regression
    """
    prediction_type = core_params["prediction_type"]
    model_type = core_params["taskType"]

    # Weighting and calibration settings are optional in core_params.
    weighting = core_params.get("weight", {}).get("weightMethod", None)
    use_sample_weight = weighting in {"SAMPLE_WEIGHT", "CLASS_AND_SAMPLE_WEIGHT"}
    use_class_weight = weighting in {"CLASS_WEIGHT", "CLASS_AND_SAMPLE_WEIGHT"}
    calibration = core_params.get("calibration", {}).get("calibrationMethod", None)

    is_binary = prediction_type == doctor_constants.BINARY_CLASSIFICATION

    if is_binary or prediction_type == doctor_constants.MULTICLASS:
        fitted_clf, _, _, _ = classification_fit(
            modeling_params,
            core_params,
            transformed_train,
            model_folder_context=None,
            gridsearch_done_fn=None,
            target_map=target_map,
            with_sample_weight=use_sample_weight,
            with_class_weight=use_class_weight,
            calibration_method=calibration)

        scorable = ScorableModel.build(fitted_clf, model_type, prediction_type, modeling_params['algorithm'],
                                       preprocessing_params, ml_overrides_params)

        # Binary and multiclass scorers share the same call signature; pick one up front.
        if is_binary:
            classification_scorer = binary_classification_scorer_with_valid
        else:
            classification_scorer = multiclass_scorer_with_valid

        def make_scorer(data, df_index):
            return classification_scorer(modeling_params, scorable, data, model_folder_context, df_index,
                                         target_map=target_map, with_sample_weight=use_sample_weight)

    elif prediction_type == doctor_constants.REGRESSION:
        fitted_clf, _, _, _ = regression_fit_single(
            modeling_params,
            core_params,
            transformed_train,
            model_folder_context=None,
            gridsearch_done_fn=None,
            with_sample_weight=use_sample_weight)

        # NOTE(review): the prediction-interval model is trained on transformed_test, which is
        # also the data the test scorer evaluates below — confirm this is intended.
        interval_model = train_prediction_interval_or_none(fitted_clf, core_params, modeling_params,
                                                           transformed_test)
        # NOTE(review): unlike the classification branch, preprocessing_params is not passed to
        # ScorableModel.build here — confirm this asymmetry is intended.
        scorable = ScorableModel.build(fitted_clf, model_type, prediction_type, modeling_params["algorithm"],
                                       overrides_params=ml_overrides_params,
                                       prediction_interval_model=interval_model)

        def make_scorer(data, df_index):
            return regression_scorer_with_valid(modeling_params, scorable, data, model_folder_context, df_index,
                                                use_sample_weight)

    else:
        raise ValueError("Prediction type %s is not valid" % prediction_type)

    # Build and score test first, then train, matching the original evaluation order.
    test_scorer = make_scorer(transformed_test, test_df_index)
    train_scorer = make_scorer(transformed_train, train_df_index)
    test_scorer.score()
    train_scorer.score()

    metrics = {
        "train_metrics": train_scorer.perf_data,
        "test_metrics": test_scorer.perf_data,
    }
    return metrics
