import logging

import pandas as pd

from dataiku.core import doctor_constants
from dataiku.base.utils import safe_unicode_str
from dataiku.doctor.individual_explainer import DEFAULT_NB_EXPLANATIONS
from dataiku.doctor.posttraining.model_information_handler import build_model_handler
from dataiku.core.percentage_progress import PercentageProgress
from dataiku.doctor.posttraining.utils import is_evaluation_with_images
from dataiku.doctor.prediction.common import check_classical_prediction_type

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Number of rows sampled from the dataset by `compute` when `computation_params`
# has no "sample_size" entry (the value is further capped by the dataset size).
DEFAULT_SAMPLE_SIZE = 1000


def compute(job_id, split_desc, core_params, preprocessing_folder, model_folder, split_folder, computation_params, fmi,
            postcompute_folder=None, train_split_desc=None, train_split_folder=None):
    """
    Compute individual explanations on a sample of the dataset and store them in "individual_explanations.json".

    The dataset (full or test, as decided by the model handler) is randomly sampled, scored, filtered to the rows
    whose prediction lies outside the configured boundaries, and the remaining rows are explained. The results are
    stored under the "perClass" entry of the JSON file, keyed by the explained class (or "unique" when no class
    was given); entries already present in the file for other keys are kept.

    :param job_id: Identifier handed to the progress tracker
    :param split_desc: Forwarded to the model handler builder
    :param core_params: Model's core parameters
    :type core_params: dict
    :param preprocessing_folder: Forwarded to the model handler builder
    :param model_folder: Forwarded to the model handler builder
    :param split_folder: Path to the split folder
    :param computation_params: Computation-specific settings. Must contain "method", "low_predictions_boundary"
        and "high_predictions_boundary"; may contain "debug_mode", "nb_explanations", "random_state",
        "class_to_compute" and "sample_size".
    :type computation_params: dict
    :param fmi: Forwarded to the model handler builder
    :param postcompute_folder: Forwarded to the model handler builder
    :param train_split_desc: Forwarded to the model handler builder
    :param train_split_folder: Forwarded to the model handler builder
    :raises ValueError: If no class is given for a multiclass model, or if the dataset is empty.
    """
    _validate_params(computation_params, core_params, split_folder)

    handler = build_model_handler(
        split_desc, core_params, preprocessing_folder, model_folder, split_folder, fmi,
        postcompute_folder=postcompute_folder,
        train_split_desc=train_split_desc,
        train_split_folder=train_split_folder)

    progress = PercentageProgress(job_id)

    prediction_type = handler.get_prediction_type()
    check_classical_prediction_type(prediction_type)

    # Read the computation settings, falling back on defaults for the optional ones
    method = computation_params.get("method")
    debug_mode = computation_params.get("debug_mode", False)
    low_predictions_boundary = computation_params.get("low_predictions_boundary")
    high_predictions_boundary = computation_params.get("high_predictions_boundary")
    nb_explanations = computation_params.get("nb_explanations", DEFAULT_NB_EXPLANATIONS)
    random_state = computation_params.get("random_state", 1337)
    class_to_compute = computation_params.get("class_to_compute")
    if class_to_compute is None:
        if prediction_type == doctor_constants.MULTICLASS:
            raise ValueError("In multiclass classification a class should be specified to compute the explanations")

    # Pick the dataset to explain; the second element returned by the handler is not needed here
    testset, _ = handler.get_full_df() if handler.use_full_df() else handler.get_test_df()

    explainer = handler.get_explainer()

    requested_size = computation_params.get("sample_size", DEFAULT_SAMPLE_SIZE)
    total_rows = testset.shape[0]
    nb_records = min(requested_size, total_rows)
    if nb_records == 0:
        raise ValueError("Can not perform computation on an empty dataset")

    explainer.make_ready()
    sample = testset.sample(n=nb_records, random_state=random_state)
    sample_pred_df = handler.get_predictor().predict(sample)
    observations_df, observations_pred_df = _sample_by_predictions(
        sample, sample_pred_df, prediction_type,
        low_predictions_boundary, high_predictions_boundary, class_to_compute)

    if not observations_df.empty:
        explanations = explainer.explain(observations_df, nb_explanations, method,
                                         for_class=class_to_compute,
                                         debug_mode=debug_mode,
                                         progress=progress)
        # Select the prediction column to report alongside the explanations
        if prediction_type == doctor_constants.REGRESSION:
            prediction_column = "prediction"
        else:
            if prediction_type == doctor_constants.MULTICLASS:
                for_class = class_to_compute
            else:
                for_class = handler.get_inv_map()[1]
            prediction_column = u"proba_{}".format(safe_unicode_str(for_class))
        predictions = observations_pred_df[prediction_column].tolist()
    else:
        # No row falls outside the boundaries: report empty explanations and predictions
        explanations = pd.DataFrame(columns=observations_df.columns)
        predictions = []

    results = {
        "explanations": explanations.to_dict(orient="list"),
        "observations": observations_df.fillna("").astype(str).to_dict(orient="list"),
        "predictions": predictions,
        "nbExplanations": nb_explanations,
        "nbRecords": nb_records,
        "onSample": nb_records < testset.shape[0],
        "randomState": random_state,
        "lowPredictionsBoundary": low_predictions_boundary,
        "highPredictionsBoundary": high_predictions_boundary,
        "method": method
    }

    # Merge the new results into the existing file, if any, so that other classes' results are kept
    output_context = handler.get_output_folder_context()
    output_file_name = "individual_explanations.json"
    if output_context.isfile(output_file_name):
        all_results = output_context.read_json(output_file_name)
    else:
        all_results = {"perClass": {}}
    result_key = "unique" if class_to_compute is None else class_to_compute
    all_results["perClass"][result_key] = results
    output_context.write_json(output_file_name, all_results)


def _validate_params(computation_params, core_params, split_folder):
    """
    Check that individual explanations can be computed with the given settings.

    :param computation_params: Computation-specific settings. Must provide the 'method',
        'low_predictions_boundary' and 'high_predictions_boundary' keys.
    :type computation_params: dict
    :param core_params: Model's core parameters
    :type core_params: dict
    :param split_folder: Path to the split folder
    :type split_folder: str
    :raises Exception: If the configuration is an evaluation with images, or if `computation_params` is None
        or lacks one of the required keys.
    """
    if is_evaluation_with_images(split_folder, core_params):
        # Preprocessing does not support having two image folders (reference and current)
        raise Exception("Computing Individual Explanations is not available for evaluations with images.")

    required_keys = ("method", "low_predictions_boundary", "high_predictions_boundary")
    if computation_params is None or any(key not in computation_params for key in required_keys):
        raise Exception(
            "'computation_params' should contains keys 'low_predictions_boundary', "
            "'high_predictions_boundary' and 'method'")


def _sample_by_predictions(df, pred_df, prediction_type, low_predictions_boundary, high_prediction_boundary, class_name=None):
    """ Return a sample of the given DataFrame using its predictions.

    If one prediction is between low_predictions_boundary and high_prediction_boundary its row is discarded
    :param pd.DataFrame df: DataFrame to sample
    :param pd.DataFrame pred_df: prediction DataFrame
    :param str prediction_type: the prediction type
    :param float low_predictions_boundary: the lower boundary
    :param float high_prediction_boundary: the upper boundary
    :param str or None class_name: the name of the class to be used for sampling by predictions. Mandatory in multiclass only.
    :return: the sample of the DataFrame
    :rtype: pd.DataFrame
    """
    if pred_df.empty:
        return RuntimeError("All rows have been dropped by the preprocessing")
    elif prediction_type == doctor_constants.REGRESSION:
        predictions = pred_df["prediction"]
        filtered_predictions = predictions[(predictions <= low_predictions_boundary) |
                                           (predictions >= high_prediction_boundary)]
    elif prediction_type == doctor_constants.BINARY_CLASSIFICATION:
        probas_1 = pred_df.iloc[:, 2]
        filtered_predictions = pred_df[(probas_1 <= low_predictions_boundary) |
                                       (probas_1 >= high_prediction_boundary)]
    else:
        if class_name is None:
            raise ValueError("The class used to sample by prediction should be specified")
        probas_selected_class = pred_df[u"proba_{}".format(safe_unicode_str(class_name))]
        filtered_predictions = pred_df[(probas_selected_class <= low_predictions_boundary) |
                                       (probas_selected_class >= high_prediction_boundary)]
    return df.reindex(index=filtered_predictions.index), pred_df.reindex(index=filtered_predictions.index)
