# coding: utf-8
from __future__ import unicode_literals

import json
import logging
import sys
import traceback
from abc import abstractmethod
from abc import ABCMeta

import pandas as pd
import six

from dataiku.base.socket_block_link import JavaLink, parse_javalink_args
from dataiku.base.utils import get_json_friendly_error
from dataiku.base.utils import watch_stdin
from dataiku.core import debugging
from dataiku.core import doctor_constants as constants
from dataiku.doctor.exploration.exploration import DoctorEmuWrapper
from dataiku.doctor.exploration.exploration import CounterfactualsStrategy
from dataiku.doctor.exploration.exploration import OutcomeOptimizationStrategy
from dataiku.doctor.exploration.emu.generators import SpecialTarget
from dataiku.doctor.posttraining.model_information_handler import build_model_handler
from dataiku.doctor.utils import add_missing_columns
from dataiku.doctor.utils import dataframe_from_dict_with_dtypes
from dataiku.doctor.utils import ml_dtypes_from_dss_schema

logger = logging.getLogger(__name__)


class InteractiveModelProtocol(object):
    """Command loop answering interactive scoring requests received over a link.

    Commands are read as JSON from ``link``. For each recognized command, either a ``{"results": ...}`` or an
    ``{"error": ...}`` JSON message is sent back on the same link.
    """

    def __init__(self, link):
        """
        :param link: object exposing ``read_json()`` and ``send_json(obj)``
        """
        self.link = link

    def _send_results(self, results):
        """Send ``results`` back on the link, wrapped as ``{"results": results}``."""
        self.link.send_json({"results": results})

    def _handle_compute_score(self, interactive_scorer, records):
        """Score the records and send one result entry per input record.

        :type interactive_scorer: AbstractInteractiveScorer
        :param records: records as [{"feature1": value, "feature2": value}, ...]
        :type records: list[dict[str, object]]
        """
        pred_df = interactive_scorer.score(records)
        # Positions of the input records, used to realign the output rows with the input
        index_before_pp = range(len(records))
        self._send_results(dataframes_to_list({"score": pred_df}, index_before_pp))

    def _handle_compute_explanation(self, interactive_scorer, computation_params, records):
        """Score and explain the records and send one result entry per input record.

        :type interactive_scorer: AbstractInteractiveScorer
        :type computation_params: dict
        :param records: records as [{"feature1": value, "feature2": value}, ...]
        :type records: list[dict[str, object]]
        """
        prediction_df, explanations_df = interactive_scorer.explain(computation_params, records)
        # Positions of the input records, used to realign the output rows with the input
        index_before_pp = range(len(records))
        self._send_results(dataframes_to_list({"score": prediction_df, "explanation": explanations_df},
                                              index_before_pp))

    def _handle_compute_counterfactuals(self, interactive_scorer, computation_params, records):
        """Compute counterfactuals and send the scorer's result as is.

        :type interactive_scorer: AbstractInteractiveScorer
        :type computation_params: dict
        :param records: records as [{"feature1": value, "feature2": value}, ...]
        :type records: list[dict[str, object]]
        """
        results = interactive_scorer.compute_counterfactuals(computation_params, records)
        self._send_results(results)

    def _handle_compute_outcome_optimization(self, interactive_scorer, computation_params, records):
        """Run outcome optimization and send the scorer's result as is.

        :type interactive_scorer: AbstractInteractiveScorer
        :type computation_params: dict
        :param records: records as [{"feature1": value, "feature2": value}, ...]
        :type records: list[dict[str, object]]
        """
        results = interactive_scorer.optimize_outcome(computation_params, records)
        self._send_results(results)

    def _handle_command_exception(self, e):
        """Print and log the failure of a command, then report it on the link as ``{"error": ...}``.

        Must be called from within the ``except`` clause handling ``e``, since the traceback is printed with
        ``traceback.print_exc()``.

        :param e: exception raised while handling the command
        :type e: Exception
        """
        traceback.print_exc()
        traceback.print_stack()
        logger.error(e)
        error = get_json_friendly_error()
        self.link.send_json({'error': error})

    def start(self):
        """Read and serve commands forever.

        The scorer is built lazily from the params of the first command for which the build succeeds, then
        reused for all following commands. A command of unknown type is only logged: nothing is sent back.
        Any ``Exception`` raised while reading or serving a command is reported through
        ``_handle_command_exception``; an exception raised by that reporting itself is not caught here and
        ends the loop.
        """
        interactive_scorer = None
        while True:
            try:
                command = self.link.read_json()
                params = json.loads(command["params"])

                if interactive_scorer is None:
                    interactive_scorer = InteractiveScorer(params)

                command_type = command["type"]
                if command_type == "SCORING":
                    self._handle_compute_score(interactive_scorer, params["records"])
                elif command_type == "EXPLANATIONS":
                    self._handle_compute_explanation(interactive_scorer, params["computation_params"],
                                                     params["records"])
                elif command_type == "COUNTERFACTUALS":
                    self._handle_compute_counterfactuals(interactive_scorer, params["computation_params"],
                                                         params["records"])
                elif command_type == "OUTCOME_OPTIMIZATION":
                    self._handle_compute_outcome_optimization(interactive_scorer, params["computation_params"],
                                                              params["records"])
                else:
                    # Go through the module logger (not the root logger) like the rest of this file
                    logger.info("Interactive Scoring - Command %s not recognized", command_type)
            except Exception as e:
                self._handle_command_exception(e)


def dataframes_to_list(df_dict, original_index):
    """ Serialize several row-aligned dataframes as one list, with one entry per label of original_index.

    For each label of original_index, in order, the entry is:
    - None when the label is absent from the index of the concatenated dataframes
    - otherwise a nested dict {dataframe name: {column: value}} holding that row's values

    NaN values are dropped, so the matching columns are simply absent from the row's dict.

    Example:
        with
        df_dict = {
            "df1": pd.DataFrame({"col1": ["row1", "row2", "row3"]}),
            "DF2": pd.DataFrame({"COL1": ["ROW1", "ROW2", "ROW3"]})
        }
        and original_index = range(3), the result is: [
            {"df1": {"col1": "row1"}, "DF2": {"COL1": "ROW1"}},
            {"df1": {"col1": "row2"}, "DF2": {"COL1": "ROW2"}},
            {"df1": {"col1": "row3"}, "DF2": {"COL1": "ROW3"}}
        ]

    :param df_dict: dataframes to serialize, keyed by the name under which their values are nested
    :param original_index: index of the original data, for realignment purpose
    :type df_dict: dict[str, pandas.DataFrame]
    :rtype: list
    """
    # Side-by-side concatenation: the dict keys become the first level of the column MultiIndex
    combined = pd.concat(df_dict, axis=1)
    known_labels = combined.index

    serialized = []
    for label in original_index:
        if label in known_labels:
            row = combined.loc[label]
            # First the name of the source dataframe, then its own columns
            nested = {}
            for frame_name in row.index.levels[0]:
                nested[frame_name] = row.xs(frame_name).dropna().to_dict()
            serialized.append(nested)
        else:
            serialized.append(None)
    return serialized


@six.add_metaclass(ABCMeta)
class AbstractInteractiveScorer(object):
    """Interface of the scorers driven by InteractiveModelProtocol.

    In every method, ``records`` is a list of dicts, one per record, as [{"feature1": value, ...}, ...].
    """

    @abstractmethod
    def score(self, records):
        """Score the records.

        The protocol serializes the returned value with dataframes_to_list, under the "score" key.
        """

    @abstractmethod
    def explain(self, computation_params, records):
        """Score and explain the records.

        The protocol unpacks the returned value as a (predictions, explanations) pair and serializes both
        with dataframes_to_list.
        """

    @abstractmethod
    def compute_counterfactuals(self, computation_params, records):
        """Compute counterfactuals for the records; the protocol sends the returned value as is."""

    @abstractmethod
    def optimize_outcome(self, computation_params, records):
        """Run outcome optimization for the records; the protocol sends the returned value as is."""


class InteractiveScorer(AbstractInteractiveScorer):
    """Scorer relying on the predictor of a model handler built from the command parameters."""

    def __init__(self, params):
        """
        :param params: command parameters; "split_desc", "core_params", "preprocessing_folder", "model_folder",
                       "split_folder" and "fmi" are read from it
        :type params: dict
        """
        self.model_handler = build_model_handler(params["split_desc"], params["core_params"],
                                                 params["preprocessing_folder"], params["model_folder"],
                                                 params["split_folder"], params["fmi"])
        self.predictor = self.model_handler.get_predictor()
        self.per_feature = self.predictor.params.preprocessing_params['per_feature']
        self.dtypes = ml_dtypes_from_dss_schema(params["split_desc"]["schema"], self.per_feature,
                                                prediction_type=params["core_params"]["prediction_type"])

        # NOTE: `emu_wrapper` only exists when the model supports exploration
        if self.model_handler.supports_exploration():
            if self.model_handler.get_prediction_type() == constants.REGRESSION:
                computer = OutcomeOptimizationStrategy()
            else:
                computer = CounterfactualsStrategy(self.model_handler.get_target_map())
            self.emu_wrapper = DoctorEmuWrapper(self.model_handler, computer)

    def score(self, records):
        """Return the predictor's output for the records.

        :param records: records as [{"feature1": value, "feature2": value}, ...]
        :type records: list[dict[str, object]]
        """
        return self.predictor.predict(self._get_dataframe(records))

    def explain(self, computation_params, records):
        """Predict with explanations and split the output in two dataframes.

        :param computation_params: "nExplanations" (default 1) and "explanationMethod" are read from it
        :type computation_params: dict
        :param records: records as [{"feature1": value, "feature2": value}, ...]
        :type records: list[dict[str, object]]
        :return: (predictions without the explanation columns,
                  explanation columns renamed without their "explanations_" prefix)
        :rtype: tuple
        """
        prediction_df = self.predictor.predict(self._get_dataframe(records),
                                               with_probas=True,
                                               with_explanations=True,
                                               n_explanations=computation_params.get("nExplanations", 1),
                                               explanation_method=computation_params.get("explanationMethod"))
        # Extract explanations
        explanations_prefix = "explanations_"
        explanations_col = [col for col in prediction_df.columns if col.startswith(explanations_prefix)]
        explanations_df = prediction_df[explanations_col]

        prediction_df = prediction_df.drop(explanations_col, axis=1)

        # Strip the leading prefix only: `str.replace` would also remove later occurrences of the prefix,
        # mangling the name of any column that itself contains "explanations_".
        # A list (not a lazy `map` object) is used so the new column names are materialized on Python 3 too.
        explanations_df.columns = [col[len(explanations_prefix):] for col in explanations_col]
        return prediction_df, explanations_df

    def compute_counterfactuals(self, computation_params, records):
        """Compute counterfactuals for a single reference record.

        :param computation_params: must hold "featureDomains"; "target" is optional (default None)
        :type computation_params: dict
        :param records: a single record, as [{"feature1": value, "feature2": value}]
        :type records: list[dict[str, object]]
        :return: one-element list wrapping the result computed for the reference
        :rtype: list
        :raises ValueError: if the records do not make exactly one row
        """
        return self._explore(computation_params, records, None,
                             "Cannot compute counterfactuals with multiple references")

    def optimize_outcome(self, computation_params, records):
        """Run outcome optimization for a single reference record.

        :param computation_params: must hold "featureDomains"; "target" is optional (default SpecialTarget.MIN)
        :type computation_params: dict
        :param records: a single record, as [{"feature1": value, "feature2": value}]
        :type records: list[dict[str, object]]
        :return: one-element list wrapping the result computed for the reference
        :rtype: list
        :raises ValueError: if the records do not make exactly one row
        """
        return self._explore(computation_params, records, SpecialTarget.MIN,
                             "Cannot optimize outcome with multiple references")

    def _explore(self, computation_params, records, default_target, not_single_reference_message):
        """Shared implementation of compute_counterfactuals and optimize_outcome.

        :param default_target: target used when computation_params has no "target"
        :param not_single_reference_message: message of the ValueError raised unless there is exactly one row
        """
        df = self._get_dataframe(records)
        if df.shape[0] != 1:
            raise ValueError(not_single_reference_message)
        self.emu_wrapper.set_constraints(computation_params["featureDomains"], df)
        self.emu_wrapper.update_generator(computation_params.get("target", default_target))
        return [self.emu_wrapper.compute(df)]  # Must return a list, so wrapping results on a single record into a list

    def _get_dataframe(self, records):
        """Build the input dataframe from the records.

        A feature absent from a record gets None for that record.
        """
        # format records as {"feature1": array, "feature2": array} like in API Node python server
        all_features = {k for d in records for k in d.keys()}
        records_as_dict = {feature: [record.get(feature) for record in records]
                           for feature in all_features}

        df = dataframe_from_dict_with_dtypes(records_as_dict, self.dtypes)
        df = add_missing_columns(df, self.dtypes, self.per_feature)

        return df


def serve(port, secret, server_cert=None):
    """Connect a JavaLink built from the given arguments and run the interactive model protocol on it.

    :param port: forwarded to JavaLink
    :param secret: forwarded to JavaLink
    :param server_cert: forwarded to JavaLink as its ``server_cert`` keyword argument
    """
    java_link = JavaLink(port, secret, server_cert=server_cert)
    java_link.connect()

    # The protocol is only started once the link is connected
    InteractiveModelProtocol(java_link).start()


if __name__ == "__main__":
    # Script entry point: configure logging, then serve on the link described by the command-line arguments
    log_format = '[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s'
    logging.basicConfig(format=log_format, level=logging.INFO)
    debugging.install_handler()

    # NOTE(review): presumably watches stdin to detect that the parent process went away - confirm in
    # dataiku.base.utils
    watch_stdin()
    link_port, link_secret, link_server_cert = parse_javalink_args()
    serve(link_port, link_secret, server_cert=link_server_cert)
