import logging

from dataiku.base.folder_context import build_folder_context
from dataiku.core import dkujson, doctor_constants
from dataiku.doctor.deephub.deephub_context import init_deephub_context
from dataiku.doctor.utils.gpu_execution import DeepHubGpuCapability, get_gpu_config_from_recipe_desc, get_gpu_config_from_core_params

# Key, in the core params, of the dataset column holding the path of each file in the managed folder
PATH_COLUMN_FIELD_NAME = "pathColumn"

logger = logging.getLogger(__name__)


class DeepHubParams(object):
    """
    Params shared by deephub training and scoring: model params, GPU configuration, folders and batch size.
    """

    def __init__(self, modeling_params, preprocessing_params, core_params, file_path_column, gpu_config, model_folder_context,
                 tmp_folder, batch_size):
        """
        :param modeling_params: how the model is built and trained (architecture, optimizer, augmentation params, ...)
               java counterpart: `com.dataiku.dip.analysis.model.prediction.DeepHubModelingParams`
        :param preprocessing_params: preprocessing parameters, mostly holding the target remapping
               java counterpart: `com.dataiku.dip.analysis.model.preprocessing.PredictionPreprocessingParams`
        :param core_params: main information about the model (prediction type, file path column, ...)
               java counterpart: `com.dataiku.dip.analysis.model.prediction.ResolvedDeepHubPredictionCoreParams`
        :param file_path_column: dataset column holding the path of each file in the managed folder
        :param gpu_config: how GPUs are used for this training/scoring
        :param model_folder_context: folder holding everything about the model (model pickle, perf, report, ...)
        :param tmp_folder: folder available when needed, e.g. for temporary files
        :param batch_size: per-device batch size for training/scoring: running on 2 GPUs gives an effective
                           batch size of (2 * batch_size)
        """
        self.modeling_params = modeling_params
        self.preprocessing_params = preprocessing_params
        self.core_params = core_params
        self.file_path_column = file_path_column
        self.gpu_config = gpu_config
        self.model_folder_context = model_folder_context
        self.tmp_folder = tmp_folder
        self.batch_size = batch_size
        # Categories handled by the model; an absent "target_remapping" yields an empty remapping
        remapping_entries = preprocessing_params.get("target_remapping", [])
        self.target_remapping = TargetRemapping(remapping_entries)

    def init_deephub_context(self, allow_cpu_fallback=False):
        """Initialize the deephub context with this node's GPU list and the tmp folder."""
        gpus = self.get_per_node_gpu_list(allow_cpu_fallback)
        # Resolves to the module-level `init_deephub_context` imported from deephub_context, not to this method
        init_deephub_context(gpus, self.tmp_folder)

    def get_per_node_gpu_list(self, allow_cpu_fallback):
        """
        :param allow_cpu_fallback: forwarded to `DeepHubGpuCapability.should_use_gpu`
        :return: the "gpuList" of the GPU config params (empty list if unset) when GPUs should be used,
                 otherwise an empty list
        """
        if DeepHubGpuCapability.should_use_gpu(self.gpu_config, allow_cpu_fallback=allow_cpu_fallback):
            return self.gpu_config["params"].get("gpuList", [])
        return []

    def is_distributed(self):
        """Not distributed by default; subclasses may override."""
        return False


class DeepHubScoringParams(DeepHubParams):
    """
    Params used to score with a deephub model, built either for a scoring recipe or for a predictor.

    On top of the common params, holds `scoring_params`: a dict of prediction-type specific options
    ("confidence_threshold" for object detection, nothing for image classification).
    """

    def __init__(self, modeling_params, preprocessing_params, core_params, file_path_column, gpu_config, model_folder_context,
                 tmp_folder, batch_size, scoring_params):
        """
        Same parameters as `DeepHubParams`, plus:

        :param scoring_params: prediction-type specific scoring options, e.g. {"confidence_threshold": 0.5} for
               object detection, {} for image classification
        """
        super(DeepHubScoringParams, self).__init__(modeling_params, preprocessing_params, core_params,
                                                   file_path_column, gpu_config, model_folder_context,
                                                   tmp_folder, batch_size)
        self.scoring_params = scoring_params

    @staticmethod
    def _is_object_detection(prediction_type):
        """
        Validate that `prediction_type` is a prediction type supported for scoring.

        :return: True for object detection, False for image classification
        :raises ValueError: for any other prediction type
        """
        if prediction_type == "DEEP_HUB_IMAGE_OBJECT_DETECTION":
            return True
        if prediction_type == "DEEP_HUB_IMAGE_CLASSIFICATION":
            return False
        raise ValueError("Unknown prediction type '{}'".format(prediction_type))

    @classmethod
    def build_for_scoring_recipe(cls, modeling_params, preprocessing_params, core_params, model_folder_context,
                                 tmp_folder, user_meta, recipe_desc):
        """
        Mostly relies on recipe_desc for runtime params (GPU, batch size, confidence threshold), else on params from the
        model (modeling, preprocessing, etc...)

        :raises ValueError: if the prediction type is neither object detection nor image classification
        :rtype: DeepHubScoringParams
        """
        gpu_config = get_gpu_config_from_recipe_desc(recipe_desc)
        # todo @deephub: must be removed for multi gpu scoring
        DeepHubGpuCapability.force_single_gpu(gpu_config)  # For scoring only 1st requested gpu
        batch_size = recipe_desc["perDeviceBatchSize"]

        scoring_params = {}
        if cls._is_object_detection(core_params["prediction_type"]):
            # Either the threshold forced by the recipe, or the one from the model's user meta
            if recipe_desc["overrideModelSpecifiedConfidenceThreshold"]:
                confidence_threshold = recipe_desc.get("forcedConfidenceThreshold", 0.)
            else:
                confidence_threshold = user_meta.get("activeClassifierThreshold", 0.)
            scoring_params["confidence_threshold"] = confidence_threshold

        return cls(modeling_params, preprocessing_params, core_params,
                   core_params[PATH_COLUMN_FIELD_NAME], gpu_config, model_folder_context, tmp_folder,
                   batch_size, scoring_params)

    @classmethod
    def build_for_predictor(cls, modeling_params, preprocessing_params, core_params, model_folder_context, tmp_folder, user_meta):
        """
        Predictor params differ a bit from scoring recipe params:
        * The input data will be directly streamed (as base64) to the model through a generic column called "input",
          rather than using the pathColumn from the core params, therefore we overwrite the `file_path_column`
        * Deciding on whether to use GPU for scoring:
            * is taken from the original core params, i.e. the scoring will be on GPU if the training was on GPU
            * for now we do not support multi-GPUs scoring, so we enforce one GPU (the one with id 0)
        * batch size is also reusing the one from the training

        :raises ValueError: if the prediction type is neither object detection nor image classification
        :rtype: DeepHubScoringParams
        """
        gpu_config = get_gpu_config_from_core_params(core_params)
        DeepHubGpuCapability.force_gpu_0(gpu_config)  # For scoring directly use gpu 0

        batch_size = modeling_params["perDeviceBatchSize"]

        scoring_params = {}
        if cls._is_object_detection(core_params["prediction_type"]):
            scoring_params["confidence_threshold"] = user_meta.get("activeClassifierThreshold", 0.)

        return cls(modeling_params, preprocessing_params, core_params, "input", gpu_config,
                   model_folder_context, tmp_folder, batch_size, scoring_params)


class DeepHubTrainingParams(DeepHubParams):
    """
    Params used to train a deephub model.

    The GPU config comes from the core params and the per-device batch size from the modeling params.
    The params can be round-tripped through a tuple of strings with `to_str_params` / `from_str_params`.
    """

    def __init__(self, core_params, modeling_params, preprocessing_params, split_desc, preprocessing_folder,
                 model_folder_context, split_folder, tmp_folder):
        """
        :param core_params: core params of the model; must contain PATH_COLUMN_FIELD_NAME and "managedFolderSmartId"
        :param modeling_params: modeling params; must contain "perDeviceBatchSize"
        :param preprocessing_params: preprocessing params (see `DeepHubParams`)
        :param split_desc: description of the train/test split
        :param preprocessing_folder: preprocessing folder
        :param model_folder_context: folder context of the model
        :param split_folder: split folder
        :param tmp_folder: folder available for temporary files
        """
        super(DeepHubTrainingParams, self).__init__(modeling_params,
                                                    preprocessing_params, core_params,
                                                    core_params[PATH_COLUMN_FIELD_NAME],
                                                    get_gpu_config_from_core_params(core_params),
                                                    model_folder_context, tmp_folder,
                                                    modeling_params["perDeviceBatchSize"])
        self.managed_folder_id = core_params["managedFolderSmartId"]
        self.preprocessing_folder = preprocessing_folder
        self.split_folder = split_folder
        self.split_desc = split_desc

    def is_distributed(self):
        """Training is distributed as soon as more than one process runs per node."""
        # Condition would be more complex with distribution over k8s
        return self.get_process_count_per_node() > 1

    def get_process_count_per_node(self):
        """
        :return: the number of GPUs to use on this node when GPUs are used (CPU fallback allowed),
                 otherwise "processCountPerNode" from the modeling params
        """
        gpu_count = len(self.get_per_node_gpu_list(allow_cpu_fallback=True))
        if gpu_count > 0:
            return gpu_count
        return self.modeling_params["processCountPerNode"]

    def get_model_optimization_split_params(self):
        """
        Params of the split used for model optimization.

        Falls back to defaults ({"trainSplitRatio": 0.8, "seed": 1337}) when the modeling params do not define them.

        :return: dict containing "trainSplitRatio", strictly between 0 and 1
        :raises ValueError: if "trainSplitRatio" is not strictly between 0 and 1
        """
        params = self.modeling_params.get("modelOptimizationSplitParams")
        # in the event of a training launched with an old version of the params (key missing or null)
        if params is None:
            params = {"trainSplitRatio": 0.8, "seed": 1337}
            logger.info("Model optimization split params are not set, using default values: %s", params)

        train_split_ratio = params["trainSplitRatio"]
        # Explicit check rather than `assert`, which is stripped when Python runs with -O
        if not 0. < train_split_ratio < 1.:
            raise ValueError("trainSplitRatio must be strictly between 0 and 1, got {}".format(train_split_ratio))
        return params

    def to_str_params(self):
        """
        Serialize these params to a tuple of strings.

        The tuple order matches the positional parameters of `from_str_params`.
        """
        return (dkujson.dumps(self.core_params),
                dkujson.dumps(self.modeling_params),
                dkujson.dumps(self.preprocessing_params),
                dkujson.dumps(self.split_desc),
                self.preprocessing_folder,
                self.model_folder_context.get_origin_folder_path(),
                self.split_folder,
                self.tmp_folder)

    @classmethod
    def from_str_params(cls, core_params_str, modeling_params_str, preprocessing_params_str, split_desc_str,
                        preprocessing_folder, model_folder_context_origin_folder, split_folder, tmp_folder):
        """
        Rebuild the params from the tuple produced by `to_str_params`.

        :rtype: DeepHubTrainingParams
        """
        core_params = dkujson.loads(core_params_str)
        modeling_params = dkujson.loads(modeling_params_str)
        preprocessing_params = dkujson.loads(preprocessing_params_str)
        split_desc = dkujson.loads(split_desc_str)
        model_folder_context = build_folder_context(model_folder_context_origin_folder)
        return cls(core_params, modeling_params, preprocessing_params, split_desc,
                   preprocessing_folder, model_folder_context, split_folder, tmp_folder)


class TargetRemapping(object):

    def __init__(self, target_remapping_list):
        """
        Lightweight helper around the target remapping of the preprocessing.

        Each entry of the raw list looks like:
            { "sourceValue": "category0", "mappedValue": 0, "sampleFreq": 1488}

        The remapping exists to swap string values for integers that models handle more gracefully.

        In deephub, this object is only used to:
         * enumerate the "valid" categories, i.e. those known to the model
         * give each category an order, so that one can go from `category` to `category_index` (its position in the
           initial target_remapping) and back.

        Models therefore represent a category by its index, and "mappedValue" is never read.
        """
        self._target_remapping_list = target_remapping_list
        categories = [entry["sourceValue"] for entry in target_remapping_list]
        self._categories = categories
        # Duplicate source values keep the index of their last occurrence
        self._category_to_index_map = {name: position for position, name in enumerate(categories)}

    def __len__(self):
        """Number of entries in the raw remapping list."""
        return len(self._target_remapping_list)

    def get_category(self, index):
        """Category at position `index` in the remapping."""
        return self._categories[index]

    def get_category_index(self, category):
        """Position of `category` in the remapping; raises KeyError for unknown categories."""
        return self._category_to_index_map[category]

    def get_category_to_index_map(self):
        """Internal dict mapping each category to its index (not a copy)."""
        return self._category_to_index_map

    def list_categories(self):
        """Internal list of categories, in remapping order (not a copy)."""
        return self._categories

    def get_index_to_category_map(self):
        """Fresh dict mapping each index to its category."""
        return dict(enumerate(self._categories))
