import logging

from dataiku.core import doctor_constants
from dataiku.doctor.prediction.common import ClassicalPredictionAlgorithm
from dataiku.doctor.prediction.common import FloatHyperparameterDimension
from dataiku.doctor.prediction.common import HyperparametersSpace
from dataiku.doctor.prediction.common import IntegerHyperparameterDimension
from dataiku.doctor.prediction.common import TrainableModel
from dataiku.doctor.utils.gpu_execution import DeepNNGpuCapability, get_gpu_config_from_core_params

logger = logging.getLogger(__name__)


class _DeepNeuralNetworkPredictionAlgorithm(ClassicalPredictionAlgorithm):
    """Common plumbing shared by the deep neural network prediction algorithms.

    Subclasses only need to implement ``_create_estimator``. Building the
    hyperparameter space, wrapping the estimator into a ``TrainableModel`` and
    reporting the parameters actually used are all handled here.
    """

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Build the trainable model for this algorithm.

        :param input_hp_space: algorithm settings, used both to define the
            hyperparameter space and to configure the estimator
        :param modeling_params: not used by this implementation
        :param core_params: handed to ``get_gpu_config_from_core_params`` to
            obtain the GPU configuration given to the estimator factory
        :return: a ``TrainableModel`` wrapping the estimator, created with
            ``supports_sample_weights=False``
        """
        # The search space is built first, before anything GPU-related is resolved.
        search_space = self._create_hp_space(input_hp_space)
        estimator = self._create_estimator(input_hp_space, get_gpu_config_from_core_params(core_params))
        return TrainableModel(estimator, hyperparameters_space=search_space, supports_sample_weights=False)

    def _create_estimator(self, input_hp_space, gpu_config):
        """Return the estimator to train; every concrete subclass has to override this."""
        raise NotImplementedError("must be implemented in subclasses")

    @staticmethod
    def _create_hp_space(input_hp_space):
        """Define the hyperparameter space from the algorithm settings.

        The ``"lr"`` dimension is built explicitly from the ``"learning_rate"``
        entry of ``input_hp_space``, while ``"hidden_layers"`` and ``"units"``
        are declared as integer dimensions by class only.
        """
        # Estimator-side name ("lr") differs from the settings-side name ("learning_rate").
        explicit_dimensions = {"lr": FloatHyperparameterDimension(input_hp_space["learning_rate"])}
        dimension_classes = dict.fromkeys(("hidden_layers", "units"), IntegerHyperparameterDimension)
        return HyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension=explicit_dimensions,
            hp_names_to_dimension_class=dimension_classes,
        )

    def actual_params(self, output, estimator, fit_params):
        """Report the parameters the estimator ended up with.

        :param output: mapping updated in place; its ``"deep_neural_network"``
            entry is (re)written by this method
        :param estimator: object whose ``get_params()`` provides the values
        :param fit_params: not used by this implementation
        :return: ``{"resolved": output, "other": {}}``
        """
        estimator_params = estimator.get_params()

        # (reported name, name in estimator.get_params()), in reporting order.
        reported_to_estimator_key = (
            ("learning_rate", "lr"),
            ("hidden_layers", "hidden_layers"),
            ("units", "units"),
            ("max_epochs", "max_epochs"),
            ("batch_size", "batch_size"),
            ("epochs", "callbacks__EpochCounter__epochs"),
            ("dropout", "module__dropout"),
            ("reg_l2", "optimizer__weight_decay"),
            ("reg_l1", "reg_l1"),
        )
        resolved = {
            reported_key: estimator_params[estimator_key]
            for reported_key, estimator_key in reported_to_estimator_key
        }
        # Publish the section before filling in the early stopping details.
        output["deep_neural_network"] = resolved

        if "callbacks__EarlyStopping" not in estimator_params:
            resolved["early_stopping_enabled"] = False
        else:
            resolved["early_stopping_enabled"] = True
            resolved["early_stopping_patience"] = estimator_params["callbacks__EarlyStopping__patience"]
            resolved["early_stopping_threshold"] = estimator_params["callbacks__EarlyStopping__threshold"]

        return {"resolved": output, "other": {}}


class DeepNeuralNetworkRegression(_DeepNeuralNetworkPredictionAlgorithm):
    """Deep neural network algorithm backed by ``DKUNeuralNetRegressor``."""

    algorithm = "DEEP_NEURAL_NETWORK_REGRESSION"

    def _create_estimator(self, input_hp_space, gpu_config):
        """Instantiate the regressor from the algorithm settings.

        :param input_hp_space: mapping providing the epochs, batch size, early
            stopping, dropout and L1/L2 regularization settings
        :param gpu_config: GPU configuration, turned into a device through
            ``DeepNNGpuCapability.get_device``
        :return: a configured ``DKUNeuralNetRegressor``
        """
        # Imported here rather than at module level, as in the rest of this file.
        from dataiku.doctor.prediction.deep_neural_network_model import DKUNeuralNetRegressor

        return DKUNeuralNetRegressor(
            max_epochs=input_hp_space["max_epochs"],
            batch_size=input_hp_space["batch_size"],
            device=DeepNNGpuCapability.get_device(gpu_config),
            early_stopping_enabled=input_hp_space["early_stopping_enabled"],
            early_stopping_patience=input_hp_space["early_stopping_patience"],
            early_stopping_threshold=input_hp_space["early_stopping_threshold"],
            dropout=input_hp_space["dropout"],
            reg_l2=input_hp_space["reg_l2"],
            reg_l1=input_hp_space["reg_l1"],
        )


class DeepNeuralNetworkClassification(_DeepNeuralNetworkPredictionAlgorithm):
    """Deep neural network algorithm backed by ``DKUNeuralNetClassifier``."""

    algorithm = "DEEP_NEURAL_NETWORK_CLASSIFICATION"

    def _create_estimator(self, input_hp_space, gpu_config):
        """Instantiate the classifier from the algorithm settings.

        :param input_hp_space: mapping providing the epochs, batch size, early
            stopping, dropout and L1/L2 regularization settings
        :param gpu_config: GPU configuration, turned into a device through
            ``DeepNNGpuCapability.get_device``
        :return: a configured ``DKUNeuralNetClassifier``
        """
        # Imported here rather than at module level, as in the rest of this file.
        from dataiku.doctor.prediction.deep_neural_network_model import DKUNeuralNetClassifier

        return DKUNeuralNetClassifier(
            max_epochs=input_hp_space["max_epochs"],
            batch_size=input_hp_space["batch_size"],
            device=DeepNNGpuCapability.get_device(gpu_config),
            early_stopping_enabled=input_hp_space["early_stopping_enabled"],
            early_stopping_patience=input_hp_space["early_stopping_patience"],
            early_stopping_threshold=input_hp_space["early_stopping_threshold"],
            dropout=input_hp_space["dropout"],
            reg_l2=input_hp_space["reg_l2"],
            reg_l1=input_hp_space["reg_l1"],
        )
