from dataiku.core import doctor_constants as constants

from .pyfunc_common import (get_mlflow_model_params, get_configured_threshold, load_evaluation_dataset_sample)
from dataikuscoring.mlflow import (mlflow_raw_predict, mlflow_classification_predict_to_scoring_data,
                                   mlflow_regression_predict_to_scoring_data)

from dataiku.external_ml.predictor import ExternalModelPredictor, ExternalModelIndividualExplainer


class MLflowIndividualExplainer(ExternalModelIndividualExplainer):
    """Individual explainer for MLflow models, backed by the evaluation dataset sample."""

    def _get_not_normalized_train_set(self):
        """Return the dataset used as "train set", loading it on first access.

        For MLflow we only have access to the evaluation dataset, so a sample
        of it is loaded from the model folder context and cached on
        ``self.trainset`` for subsequent calls.
        """
        if self.trainset is not None:
            # Already loaded by a previous call: reuse the cached sample.
            return self.trainset
        self.trainset = load_evaluation_dataset_sample(self._model_folder_context)
        return self.trainset

    def _get_train_set(self):
        """Return the same data as :meth:`_get_not_normalized_train_set`."""
        return self._get_not_normalized_train_set()


class MLflowPredictor(ExternalModelPredictor):
    """Predictor that scores dataframes with an MLflow pyfunc model."""

    def __init__(self, model_folder_context):
        """Initialize the predictor from a model folder context.

        :param model_folder_context: context giving access to the model folder;
            forwarded to the parent constructor.
        """
        super(MLflowPredictor, self).__init__(model_folder_context)
        self._schema = None
        # The threshold is read through the folder context carried by the params.
        folder_context = self.params.model_folder_context
        self.used_threshold = get_configured_threshold(folder_context)

    def get_individual_explainer_class(self):
        """Return the explainer class to use with this predictor."""
        return MLflowIndividualExplainer

    def load_model(self):
        """Load and return the MLflow pyfunc model stored in the model folder."""
        import mlflow
        folder_context = self.params.model_folder_context
        with folder_context.get_folder_path_to_read() as model_folder_path:
            loaded_model = mlflow.pyfunc.load_model(model_folder_path)
        return loaded_model

    def load_params(self, model_folder_context):
        """Return the MLflow model params read from ``model_folder_context``."""
        return get_mlflow_model_params(model_folder_context)

    def predict(self, df, with_probas=True):
        """Score ``df`` and return the ``pred_and_proba_df`` of the scoring data.

        The target column, when present in ``df``, is dropped before scoring.

        :param df: dataframe to score.
        :param with_probas: accepted for interface compatibility; it is not
            read by this implementation.
        :raises ValueError: if the prediction type is neither a classification
            type nor regression.
        """
        target_column = self.params.core_params["target_variable"]
        if target_column in df.columns:
            df = df.drop(target_column, axis=1)

        prediction_type = self._prediction_type
        classification_types = (constants.BINARY_CLASSIFICATION, constants.MULTICLASS)

        if prediction_type in classification_types:
            return mlflow_classification_predict_to_scoring_data(
                self.model, self.params.model_meta, df, self.used_threshold
            ).pred_and_proba_df

        if prediction_type == constants.REGRESSION:
            return mlflow_regression_predict_to_scoring_data(
                self.model, self.params.model_meta, df
            ).pred_and_proba_df

        raise ValueError("prediction_type '{}' not suported.".format(prediction_type))

    def predict_raw(self, df, force_json_tensors_output=True):
        """Score ``df`` with the raw MLflow prediction.

        When prediction type is not set we can still score this way.

        :param df: dataframe to score.
        :param force_json_tensors_output: forwarded as-is to ``mlflow_raw_predict``.
        """
        return mlflow_raw_predict(
            self.model,
            self.params.model_meta,
            df,
            force_json_tensors_output=force_json_tensors_output,
        )
