import importlib

from dataiku.base.model_plugin import prepare_for_plugin
from dataiku.doctor.plugins.plugin_params import get_prediction_algo_params
from dataiku.doctor.prediction.common import ClassicalPredictionAlgorithm
from dataiku.doctor.prediction.common import TrainableModel
from dataiku.doctor.prediction.common import GridHyperparametersSpace
from dataiku.doctor.utils import doctor_constants


class PluginPredictionAlgorithm(ClassicalPredictionAlgorithm):
    """Prediction algorithm whose implementation is supplied by a plugin.

    The plugin element's ``algo`` module is imported dynamically and must
    expose a ``CustomPredictionAlgorithm`` class. That class must provide
    ``set_grid_params`` and ``get_clf``. The hooks ``get_grid``,
    ``get_fit_params``, ``get_best_clf_grid_params`` and ``get_other`` are
    optional: each is only called when the plugin instance has it.
    """
    algorithm = "CUSTOM_PLUGIN"

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Instantiate the plugin algorithm and wrap its estimator.

        The plugin instance is kept on ``self.plugin_algo`` so that
        ``actual_params`` can query it later.

        :param input_hp_space: not used by this implementation.
        :param modeling_params: modeling settings; its "plugin_python_grid"
            entry must hold "pluginId", "elementId" and "params", and may hold
            "supportsSampleWeights" (defaults to False).
        :param core_params: core settings; must hold the prediction type.
        :return: a PluginPredictionTrainableModel around the plugin estimator.
        """
        prediction_type = core_params[doctor_constants.PREDICTION_TYPE]

        grid_info = modeling_params["plugin_python_grid"]
        plugin_id = grid_info["pluginId"]
        element_id = grid_info["elementId"]

        # prepare_for_plugin must run before the import below; presumably it
        # makes the plugin code importable (TODO confirm in dataiku.base).
        module_path = "dku-ml-plugins.{}.python-prediction-algos.{}.algo".format(plugin_id, element_id)
        prepare_for_plugin(plugin_id, element_id)
        plugin_module = importlib.import_module(module_path)

        # Names of the parameters flagged with "gridParam" in the plugin's
        # parameter definitions.
        param_definitions = get_prediction_algo_params(plugin_id, element_id)
        grid_param_names = []
        for definition in param_definitions.get("params", []):
            if definition.get("gridParam", False):
                grid_param_names.append(definition["name"])

        with_sample_weights = grid_info.get("supportsSampleWeights", False)

        plugin_algo = self.plugin_algo = plugin_module.CustomPredictionAlgorithm(
            prediction_type=prediction_type,
            params=grid_info["params"],
        )
        plugin_algo.set_grid_params(grid_param_names)

        estimator = plugin_algo.get_clf()

        # Optional hooks: fall back to empty dicts when the plugin omits them.
        if hasattr(plugin_algo, "get_grid"):
            param_grid = plugin_algo.get_grid()
        else:
            param_grid = {}

        if hasattr(plugin_algo, "get_fit_params"):
            extra_fit_params = plugin_algo.get_fit_params()
        else:
            extra_fit_params = {}

        return PluginPredictionTrainableModel(
            estimator,
            hyperparameters_space=GridHyperparametersSpace(param_grid),
            supports_sample_weights=with_sample_weights,
            plugin_fit_parameters=extra_fit_params
        )

    def actual_params(self, ret, clf, fit_params):
        """Build the resolved-parameters payload for a fitted estimator.

        ``ret`` is mutated in place: when the plugin defines
        ``get_best_clf_grid_params``, its result is stored under
        "plugin_python". This happens before ``get_other`` is called, so
        ``get_other`` sees the updated ``ret``.

        :param ret: dict of resolved parameters (modified in place).
        :param clf: the fitted estimator, forwarded to the plugin hooks.
        :param fit_params: fit parameters, forwarded to the plugin hooks.
        :return: {"resolved": ret, "other": <get_other result or {}>}
        """
        plugin_algo = self.plugin_algo

        if hasattr(plugin_algo, "get_best_clf_grid_params"):
            ret["plugin_python"] = plugin_algo.get_best_clf_grid_params(clf, fit_params)

        if hasattr(plugin_algo, "get_other"):
            other = plugin_algo.get_other(ret, clf, fit_params)
        else:
            other = {}

        return {"resolved": ret, "other": other}


class PluginPredictionTrainableModel(TrainableModel):
    """TrainableModel that merges plugin-supplied extra arguments into the fit parameters.

    ``plugin_fit_parameters`` is merged over the fit parameters computed by the
    base class. On a key collision, the plugin value wins.
    """

    def __init__(self, estimator, hyperparameters_space,
                 supports_sample_weights, plugin_fit_parameters=None):
        """
        :param estimator: the estimator returned by the plugin.
        :param hyperparameters_space: search space passed to the base class.
        :param supports_sample_weights: forwarded to the base class.
        :param plugin_fit_parameters: extra fit keyword arguments. Any falsy
            value (None, empty dict) is treated as "no extra arguments".
        """
        super(PluginPredictionTrainableModel, self).__init__(
            estimator,
            hyperparameters_space=hyperparameters_space,
            supports_sample_weights=supports_sample_weights
        )

        extra_params = plugin_fit_parameters
        if not extra_params:
            extra_params = {}
        self._plugin_fit_parameters = extra_params

    def get_fit_parameters(self, sample_weight=None, X_eval=None, y_eval=None,
                           is_final_fit=False):
        """Return the base fit parameters updated with the plugin's extra ones.

        The dict returned by the base class is updated in place and returned.
        """
        parent = super(PluginPredictionTrainableModel, self)
        merged = parent.get_fit_parameters(sample_weight, X_eval, y_eval, is_final_fit)

        # dict.update: plugin-provided keys override the base class' values.
        merged.update(self._plugin_fit_parameters)
        return merged
