from __future__ import unicode_literals

from abc import ABCMeta
from abc import abstractmethod
from enum import Enum

import numpy as np
from six import add_metaclass

# Important: keep these data structures and values in sync with their Java counterparts defined in
# MLOverridesParamsBase.java and OverrideInfo.java.

# OVERRIDES_FILE_NAME: JSON file, looked up in the model folder, that stores the overrides params.
# OVERRIDE_INFO_COL: presumably the column name carrying per-row override info (confirm against OverrideInfo.java).
OVERRIDES_FILE_NAME, OVERRIDE_INFO_COL = ("roverrides.json", "override")


@add_metaclass(ABCMeta)
class AbstractMLOutcome(object):
    """
    Base class of the outcomes an override can enforce on the predictions of matching rows.
    """

    class Type(Enum):
        # Outcomes are serialized by member name (see to_dict / from_dict);
        # keep this enum in sync with the Java side.
        INTERVAL = 0
        CATEGORY = 1
        DECLINED = 2

    def __init__(self, type_):
        """
        :param AbstractMLOutcome.Type type_: kind of outcome, serialized under the "type" key
        """
        self.type = type_

    @abstractmethod
    def apply(self, predictions):
        """
        Compute the overridden predictions.

        :type predictions: np.ndarray
        :return: an array with the same shape as the predictions
        :rtype: np.ndarray
        """
        pass

    @classmethod
    def from_dict(cls, obj):
        """
        Build the concrete outcome described by a serialized dict.

        :param dict obj: dict with a "type" key holding a Type member name, plus "category"
                         for CATEGORY outcomes or "minValue"/"maxValue" for INTERVAL outcomes
        :return: a CategoryOutcome, an IntervalOutcome or a DeclinedOutcome
        :rtype: AbstractMLOutcome
        :raises KeyError: if "type" is missing or is not a Type member name
        """
        assert isinstance(obj, dict)
        # The parsed type is per-call data: keep it in a local rather than storing it
        # on the class, which would leak shared state across every from_dict call.
        outcome_type = cls.Type[obj.get("type")]
        if outcome_type == cls.Type.CATEGORY:
            return CategoryOutcome(str(obj.get("category")))
        if outcome_type == cls.Type.INTERVAL:
            return IntervalOutcome(float(obj.get("minValue")), float(obj.get("maxValue")))
        # Only DECLINED remains.
        return DeclinedOutcome()

    def to_dict(self):
        """
        Serialize the fields common to every outcome; subclasses add their own keys.

        :rtype: dict
        """
        return {
            "type": self.type.name
        }


class CategoryOutcome(AbstractMLOutcome):
    """
    Outcome forcing every matching prediction to a single class label.
    """

    def __init__(self, category):
        """
        :param str category: the class label enforced on matching rows
        """
        super(CategoryOutcome, self).__init__(AbstractMLOutcome.Type.CATEGORY)
        self.category = category

    def apply(self, predictions):
        """
        Replace every prediction with the enforced category.

        :param np.ndarray predictions: predictions of the matching rows
        :return: an array shaped like ``predictions`` and filled with the category
        :rtype: np.ndarray
        """
        target_shape = predictions.shape
        return np.full(target_shape, self.category)

    def to_dict(self):
        """
        Serialize this outcome, adding the enforced category to the base fields.

        :rtype: dict
        """
        serialized = super(CategoryOutcome, self).to_dict()
        serialized.update({"category": self.category})
        return serialized


class IntervalOutcome(AbstractMLOutcome):
    """
    Outcome clipping every matching prediction into ``[min_value, max_value]``.
    """

    def __init__(self, min_value, max_value):
        """
        :param float min_value: lower bound predictions are clipped to
        :param float max_value: upper bound predictions are clipped to
        """
        super(IntervalOutcome, self).__init__(AbstractMLOutcome.Type.INTERVAL)
        self.min_value = min_value
        self.max_value = max_value

    def apply(self, predictions):
        """
        Clip the predictions into the configured interval.

        :param np.ndarray predictions: predictions of the matching rows
        :return: the clipped predictions, same shape as the input
        :rtype: np.ndarray
        """
        lower, upper = self.min_value, self.max_value
        return np.clip(predictions, lower, upper)

    def to_dict(self):
        """
        Serialize this outcome, adding both interval bounds to the base fields.

        :rtype: dict
        """
        serialized = super(IntervalOutcome, self).to_dict()
        serialized.update({
            "minValue": self.min_value,
            "maxValue": self.max_value,
        })
        return serialized


class DeclinedOutcome(AbstractMLOutcome):
    """
    Outcome declining to predict: every matching row receives NaN.
    """

    def __init__(self):
        # Must be DECLINED (not INTERVAL): to_dict() serializes this type's name, and
        # AbstractMLOutcome.from_dict only maps non-CATEGORY/non-INTERVAL types to a
        # DeclinedOutcome. With INTERVAL, from_dict would try float(None) on the
        # missing "minValue"/"maxValue" keys and fail.
        super(DeclinedOutcome, self).__init__(self.Type.DECLINED)

    def apply(self, predictions):
        """
        Replace every prediction with NaN.

        :param np.ndarray predictions: predictions of the matching rows
        :return: a float array shaped like ``predictions`` and filled with NaN
        :rtype: np.ndarray
        """
        return np.full(predictions.shape, np.nan)


class Override(object):
    """
    A named rule: rows matching ``filter`` get their prediction replaced by ``outcome``.
    """

    def __init__(self, name, filter_, outcome):
        """
        :param str name: identifier of the override
        :param str filter_: FilterDesc structure serialized as a string; it is forwarded
                            untouched to the backend, which generates the flag steps from it
        :param AbstractMLOutcome outcome: outcome applied to every matching row
        """
        self.name, self.filter, self.outcome = name, filter_, outcome

    @staticmethod
    def from_dict(obj):
        """
        Build an override from its serialized form.

        :param dict obj: dict with "name", "filter" and "outcome" keys
        :rtype: Override
        """
        return Override(
            obj.get("name"),
            obj.get("filter"),
            AbstractMLOutcome.from_dict(obj.get("outcome")),
        )

    def to_dict(self):
        """
        Serialize this override, including its nested outcome.

        :rtype: dict
        """
        serialized_outcome = self.outcome.to_dict()
        return {"name": self.name, "filter": self.filter, "outcome": serialized_outcome}


class MlOverridesParams(object):
    """
    The ordered list of overrides attached to a model.
    """

    def __init__(self, overrides):
        """
        :param list[Override] overrides: when built through MlOverridesParams.from_dict,
                                         this list is guaranteed to be non-empty
        """
        self.overrides = overrides

    @staticmethod
    def from_dict(obj):
        """
        Build the overrides params from their serialized form.

        :param dict obj: dict whose "overrides" key holds a list of serialized overrides
        :return: None when "overrides" is missing, null or empty
        :rtype: MlOverridesParams or None
        """
        raw_overrides = obj.get("overrides")
        if raw_overrides is None or len(raw_overrides) == 0:
            return None
        return MlOverridesParams(list(map(Override.from_dict, raw_overrides)))

    def to_dict(self):
        """
        Serialize every override, preserving their order.

        :rtype: dict
        """
        return {"overrides": [override.to_dict() for override in self.overrides]}


def ml_overrides_params_from_model_folder(model_folder_context):
    """
    Load the overrides params stored in a model folder, if any.

    :param dataiku.base.folder_context.FolderContext model_folder_context: folder to read from
    :return: None when the folder has no overrides file, or when that file defines no override
    :rtype: MlOverridesParams or None
    """
    if model_folder_context.isfile(OVERRIDES_FILE_NAME):
        return MlOverridesParams.from_dict(model_folder_context.read_json(OVERRIDES_FILE_NAME))
    return None


