from dataiku.core import doctor_constants


class TreatmentMap(object):
    """Map treatment values to integer indices, with the control mapped to 0.

    The distinct non-control treatment values are sorted and numbered
    1, 2, ... in that order; ``control_value`` always maps to 0, whether or
    not it appears in ``treatment_values``.

    Attributes:
        control_value: the value identifying the control group.
        treatment_values: the treatment values as provided, minus the missing
            ones when ``drop_missing_treatment_value`` is true.
        mapping: dict from treatment value to its integer index.
    """

    def __init__(self, control_value, treatment_values, drop_missing_treatment_value):
        """Build the mapping.

        :param control_value: value of the control group (mapped to 0)
        :param treatment_values: iterable of hashable, mutually sortable values
        :param drop_missing_treatment_value: when true, ``None`` and empty
            values (e.g. ``""``) are removed before building the mapping
        """
        self.control_value = control_value
        if drop_missing_treatment_value:
            self.treatment_values = [t for t in treatment_values if not self._is_missing(t)]
        else:
            self.treatment_values = treatment_values
        # De-duplicate before numbering: enumerating a list that still contains
        # repeated values would leave gaps in the assigned indices.
        distinct_non_control = {t for t in self.treatment_values if t != control_value}
        self.mapping = {t: i + 1 for i, t in enumerate(sorted(distinct_non_control))}
        self.mapping[control_value] = 0

    @staticmethod
    def _is_missing(value):
        """Return True for ``None`` and for empty sized values such as ``""``."""
        if value is None:
            return True
        try:
            return len(value) == 0
        except TypeError:
            # Values without a length (e.g. numbers) are never considered missing.
            return False

    def items_except_control(self):
        """Yield ``(treatment_value, index)`` pairs, skipping the control value."""
        for k, v in self.mapping.items():
            if k == self.control_value:
                continue
            yield k, v

    def items(self):
        """Return all ``(treatment_value, index)`` pairs, control included."""
        return self.mapping.items()


def check_causal_prediction_type(prediction_type):
    """Validate that ``prediction_type`` is a supported causal prediction type.

    :param prediction_type: value to look up in
        ``doctor_constants.CAUSAL_PREDICTION_TYPES``
    :raises ValueError: if ``prediction_type`` is not one of those types
    """
    if prediction_type not in doctor_constants.CAUSAL_PREDICTION_TYPES:
        # %-formatting instead of string concatenation so a non-string value
        # (e.g. None) still produces the intended ValueError, not a TypeError.
        raise ValueError("Unsupported prediction type: %s" % (prediction_type,))
