from dataiku.eda.types import Literal

from sklearn.isotonic import IsotonicRegression

from dataiku.eda.computations.immutable_data_frame import FloatVector
from dataiku.eda.curves.curve import Curve
from dataiku.eda.curves.curve import ParametrizedCurve
from dataiku.eda.types import IsotonicCurveModel, ParametrizedIsotonicCurveModel


class IsotonicCurve(Curve):
    """Curve family backed by scikit-learn's ``IsotonicRegression``.

    The family has no tunable parameters, so every instance is equivalent.
    Fitting yields a ``ParametrizedIsotonicCurve`` wrapping the fitted regressor.
    """

    @staticmethod
    def get_type() -> Literal["isotonic"]:
        """Return the discriminator string that identifies this curve family."""
        curve_type: Literal["isotonic"] = "isotonic"
        return curve_type

    @staticmethod
    def build(params: IsotonicCurveModel) -> 'IsotonicCurve':
        """Instantiate the family from its model.

        :param params: the curve model; nothing in it is read, because the
            isotonic family has no configuration.
        :return: a fresh ``IsotonicCurve``.
        """
        return IsotonicCurve()

    def fit(self, x: FloatVector, y: FloatVector) -> 'ParametrizedIsotonicCurve':
        """Fit an isotonic regression of ``y`` on ``x``.

        The regressor is created with scikit-learn's default settings.

        :param x: input values.
        :param y: target values, paired element-wise with ``x``.
        :return: the fitted curve, ready to be applied or serialized.
        """
        regressor = IsotonicRegression()
        regressor.fit(x, y)
        return ParametrizedIsotonicCurve(regressor)


class ParametrizedIsotonicCurve(ParametrizedCurve):
    """A fitted isotonic curve, holding the trained ``IsotonicRegression``."""

    def __init__(self, ir: IsotonicRegression):
        # Fitted regressor used by apply(); public attribute name kept as-is.
        self.ir = ir

    def serialize(self) -> ParametrizedIsotonicCurveModel:
        """Return the serializable model of this curve.

        Only the curve type is emitted. The fitted step thresholds are left out
        on purpose, because in the worst case they are as large as the data.
        """
        serialized: ParametrizedIsotonicCurveModel = {"type": IsotonicCurve.get_type()}
        return serialized

    def apply(self, x: FloatVector) -> FloatVector:
        """Evaluate the fitted curve at every point of ``x``.

        :param x: points at which to evaluate the curve.
        :return: the regressor's predictions for ``x``.
        """
        # NOTE(review): with sklearn's default out_of_bounds='nan', points outside
        # the fitted x range predict NaN. Confirm callers only evaluate in range.
        predictions = self.ir.predict(x)
        return predictions
