import json
import pandas as pd
import logging

from dataiku.external_ml.proxy_model.common.outputformat import BaseReader

# Module-level logger used by SagemakerJSONReader for format detection and parsing messages.
logger = logging.getLogger(__name__)


class SagemakerJSONReader(BaseReader):
    """Read predictions returned as JSON by a SageMaker endpoint.

    The expected payload is a JSON object holding a list of per-row objects
    under ``"predictions"``, e.g.
    ``{"predictions": [{"predicted_label": 1, "score": 0.8}, ...]}``.
    Each ``read_*`` method turns such a payload into a pandas DataFrame with a
    ``prediction`` column and, for classification, ``proba_<class>`` columns.
    """

    NAME = "OUTPUT_SAGEMAKER_JSON"

    def __init__(self, prediction_type, value_to_class):
        """
        :param prediction_type: type of the model; "BINARY_CLASSIFICATION",
            "MULTICLASS" and "REGRESSION" are the values handled by this reader.
        :param value_to_class: indexable by class index (0, 1, ...); gives the
            class label used to name the ``proba_<label>`` columns.
        """
        # NOTE(review): BaseReader presumably stores both arguments as
        # self.prediction_type / self.value_to_class (both are read below).
        super(SagemakerJSONReader, self).__init__(prediction_type, value_to_class)
        # Per-row key holding the prediction; detected lazily on first use.
        self.prediction_property = None
        # Top-level key of the JSON payload holding the list of per-row results.
        self.predictions_property = "predictions"

    def init_prediction_property(self, predictions):
        """Decide, once, which per-row key holds the prediction.

        Classification always uses ``"predicted_label"``. Regression uses
        ``"predicted_label"`` if the first row has it, else ``"score"``, else
        the first key of the first row. The choice is cached in
        ``self.prediction_property`` so later calls are no-ops.

        For regression with no row to inspect, nothing is cached so that a
        later, non-empty batch can decide.

        :param predictions: list of per-row dicts (see ``parse_endpoint_output``).
        """
        if self.prediction_property:
            return
        if self.prediction_type in ["MULTICLASS", "BINARY_CLASSIFICATION"]:
            self.prediction_property = "predicted_label"
        elif self.prediction_type == "REGRESSION":
            if not predictions:
                # predictions[0] would raise IndexError on an empty batch.
                return
            first_row = predictions[0]
            if "predicted_label" in first_row:
                self.prediction_property = "predicted_label"
            elif "score" in first_row:
                self.prediction_property = "score"
            else:
                # Default avoids a bare StopIteration when the first row is an empty object.
                self.prediction_property = next(iter(first_row), None)
        logger.info("Getting prediction using property %s of entries of %s",
                    self.prediction_property, self.predictions_property)

    def can_read(self, endpoint_output):
        """Tell whether ``endpoint_output`` is a SageMaker JSON payload this reader handles.

        :param endpoint_output: raw endpoint response, passed to ``json.loads``.
        :returns: True if it decodes as a JSON object holding a list of objects
            under ``"predictions"`` (the prediction property is then initialized
            from that list); False otherwise, rather than raising.
        """
        try:
            predictions = self.parse_endpoint_output(endpoint_output)
        except ValueError as json_exception:
            # Covers json.JSONDecodeError and UnicodeDecodeError (undecodable bytes).
            logger.info("Predictions are not in JSON format")
            logger.debug("JSON Parse exception: %s", json_exception)
            return False
        logger.info("Predictions are in JSON format")
        if not isinstance(predictions, list) or not all(isinstance(row, dict) for row in predictions):
            # Valid JSON but not the expected shape: the read_* methods would fail on it.
            logger.info("JSON output has no list of objects under '%s'", self.predictions_property)
            return False
        self.init_prediction_property(predictions)
        return True

    def read_binary(self, endpoint_output):
        """Parse a binary-classification payload into a DataFrame (see ``read_binary_parsed``)."""
        predictions = self.parse_endpoint_output(endpoint_output)
        return pd.DataFrame(self.read_binary_parsed(predictions))

    def read_multiclass(self, endpoint_output):
        """Parse a multiclass payload into a DataFrame (see ``read_multiclass_parsed``)."""
        predictions = self.parse_endpoint_output(endpoint_output)
        return pd.DataFrame(self.read_multiclass_parsed(predictions))

    def read_regression(self, endpoint_output):
        """Parse a regression payload into a DataFrame (see ``read_regression_parsed``)."""
        predictions = self.parse_endpoint_output(endpoint_output)
        return pd.DataFrame(self.read_regression_parsed(predictions))

    def read_binary_parsed(self, parsed_endpoint_output):
        """Build binary-classification result rows.

        :param parsed_endpoint_output: list of per-row dicts.
        :returns: one dict per row with ``prediction`` and, when the row has a
            ``"score"`` (taken as the probability of class index 1), the
            ``proba_<class 1>`` and ``proba_<class 0>`` entries.
        :raises KeyError: if a row lacks the prediction property.
        """
        self.init_prediction_property(parsed_endpoint_output)

        results = []
        for prediction in parsed_endpoint_output:
            # FIXME: should we not use the value to class here too?
            result = {"prediction": prediction[self.prediction_property]}  # fail if the prediction property is missing
            proba_1 = prediction.get("score")
            # Compare to None: a score of 0.0 is a valid probability and must not be dropped.
            if proba_1 is not None:
                result["proba_{}".format(self.value_to_class[1])] = proba_1
                result["proba_{}".format(self.value_to_class[0])] = 1 - proba_1
            results.append(result)
        return results

    def read_multiclass_parsed(self, parsed_endpoint_output):
        """Build multiclass result rows.

        A float label is converted with ``int()`` (e.g. 2.0 -> 2). When a row
        has an iterable ``"score"``, its i-th entry becomes
        ``proba_<value_to_class[i]>``. Non-iterable scores are skipped and
        reported with a single warning.

        :param parsed_endpoint_output: list of per-row dicts.
        :returns: one dict per row.
        """
        self.init_prediction_property(parsed_endpoint_output)

        results = []
        found_non_iterable_probas = False
        for prediction in parsed_endpoint_output:
            label = prediction.get(self.prediction_property)
            if isinstance(label, float):
                label = int(label)
            result = {"prediction": label}
            probas = prediction.get("score")
            if probas is not None:
                try:
                    for i, proba in enumerate(probas):
                        result["proba_{}".format(self.value_to_class[i])] = proba
                except TypeError:
                    # Only remember the problem; logging once per row could flood the logs.
                    found_non_iterable_probas = True
            results.append(result)
        if found_non_iterable_probas:
            logger.warning("Found non-iterable probas.")
        return results

    def read_regression_parsed(self, parsed_endpoint_output):
        """Build regression result rows.

        :param parsed_endpoint_output: list of per-row dicts.
        :returns: one ``{"prediction": value}`` dict per row; value is None when
            the row lacks the prediction property.
        """
        self.init_prediction_property(parsed_endpoint_output)
        return [{"prediction": prediction.get(self.prediction_property)}
                for prediction in parsed_endpoint_output]

    def parse_endpoint_output(self, endpoint_output):
        """Decode the endpoint's JSON and return its ``"predictions"`` entry.

        :param endpoint_output: raw endpoint response, passed to ``json.loads``.
        :returns: the value under ``self.predictions_property``, or None when the
            decoded JSON is not an object or lacks that key.
        :raises json.JSONDecodeError: if ``endpoint_output`` is not valid JSON.
        """
        from_json = json.loads(endpoint_output)
        if not isinstance(from_json, dict):
            # e.g. a top-level JSON array: .get() would raise AttributeError.
            return None
        return from_json.get(self.predictions_property)
