# encoding: utf-8
"""
Single-thread hoster for a "legacy" scikit-learn based predictor
"""

import json
import logging
import os.path as osp
import numpy as np
import pandas as pd
import sys
import time
import traceback

try:
    import cPickle as pickle
except:
    import pickle

from dataiku.base.utils import watch_stdin, get_json_friendly_error
from dataiku.base.folder_context import build_noop_folder_context
from dataiku.base.socket_block_link import JavaLink, parse_javalink_args
from dataiku.core import debugging
from dataiku.core.dkujson import load_from_filepath
from dataiku.core.saved_model import build_predictor_for_saved_model
from dataiku.core.saved_model import CausalPredictor
from dataiku.doctor import utils
from dataiku.doctor.utils import add_missing_columns
from dataiku.doctor.utils import dataframe_from_dict_with_dtypes
from dataiku.doctor.utils import doctor_constants

from dataiku.apinode.predict.timeseries import pred_to_dict_timeseries


# Process-wide logging configuration for this standalone server, followed by
# the installation of dataiku's debugging handler.
logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO)
debugging.install_handler()


# Prefix of the explanation columns in the predictor's output dataframe; it is
# stripped from the column names to build the per-record "explanations" keys.
EXPLANATION_COL_PREFIX = "explanations_"

def pred_to_dict(pred_df, nb_records, explanations_cols=None, with_pred_intervals=False):
    """Turn a prediction dataframe into one JSON-ready dict per input record.

    Records absent from ``pred_df`` are reported as ignored entries. The other
    records carry "prediction" and "ignored". When ``with_pred_intervals`` is
    set they also carry the prediction interval bounds, renamed to camelCase.
    When ``explanations_cols`` is given they also carry their explanations.

    :param pred_df: predictions returned by the predictor
    :param nb_records: number of records sent for scoring
    :param explanations_cols: names of the explanation columns, or None
    :param with_pred_intervals: whether to output the prediction interval bounds
    :return: list with one dict per input record
    """
    completed_df = _add_ignored_records(pred_df, nb_records)
    output_columns = ["prediction", "ignored"]
    if with_pred_intervals:
        interval_renames = {
            "prediction_interval_lower": "predictionIntervalLower",
            "prediction_interval_upper": "predictionIntervalUpper",
        }
        completed_df = completed_df.rename(columns=interval_renames)
        output_columns.extend(interval_renames.values())
    logging.info('pre to_dict: %s', completed_df)
    return _get_list_of_dicts(completed_df, output_columns, explanations_cols)


def pred_to_dict_causal_binary_treatment(pred_df, nb_records, with_propensity=False):
    """Turn a binary-treatment causal prediction dataframe into one dict per input record.

    Scored records expose "predictedEffect" (renamed from "predicted_effect") and
    "ignored", plus "propensity" when ``with_propensity`` is set. Records absent
    from ``pred_df`` are reported as ignored entries.
    """
    completed_df = _add_ignored_records(pred_df, nb_records).rename(
        columns={"predicted_effect": "predictedEffect"})
    logging.info('pre to_dict: %s', completed_df)
    propensity_columns = ["propensity"] if with_propensity else []
    return _get_list_of_dicts(completed_df, ["predictedEffect", "ignored"] + propensity_columns)


def pred_to_dict_causal_multi_val_treatment(pred_df, nb_records, treatments, control, with_propensity=False):
    """Turn a multi-valued-treatment causal prediction dataframe into one dict per input record.

    For each scored record, the output contains:
    - "predictedEffects", mapping every non-control treatment to its
      "predicted_effect_<treatment>" value;
    - "predictedBestTreatment", the treatment with the largest predicted effect;
    - "propensities", mapping every treatment (control included) to its
      "propensity_<treatment>" value, only when ``with_propensity`` is set.

    Records absent from ``pred_df`` are reported as ignored entries.
    """
    completed_df = _add_ignored_records(pred_df, nb_records)
    results = []
    for record in completed_df.to_dict(orient="records"):
        if record['ignored']:
            results.append({'ignored': True, 'ignoreReason': "IGNORED_BY_MODEL"})
            continue
        effects = {t: record["predicted_effect_" + t] for t in treatments if t != control}
        entry = {
            "predictedEffects": effects,
            # first treatment wins on ties, as max() keeps the earliest maximum
            "predictedBestTreatment": max(effects, key=effects.get),
        }
        if with_propensity:
            entry["propensities"] = {t: record["propensity_" + t] for t in treatments}
        results.append(entry)
    return results

def _add_ignored_records(pred_df, nb_records):
    """Align predictions on all input records and flag the ones the predictor did not return.

    The result is outer-joined on the index with ``range(nb_records)``, so it has
    one row per input record. Every row whose index label is absent from
    ``pred_df`` has ``ignored`` set to True and NaN in the prediction columns.
    Rows present in ``pred_df`` have ``ignored`` set to False.

    Assumes ``pred_df`` is indexed by the position of the record in the request
    (callers index the output list with those labels) — TODO confirm.

    :param pred_df: predictions returned by the predictor; it is NOT modified
    :param nb_records: number of records sent for scoring
    :return: a new dataframe with an extra boolean ``ignored`` column
    """
    scored_index = pred_df.index
    all_records = pd.DataFrame(index=range(nb_records))
    merged = all_records.merge(pred_df, how='outer', left_index=True, right_index=True)
    # Derive the flag from index membership instead of filling NaNs with a chained
    # inplace fillna: that pattern does not write back under pandas copy-on-write,
    # and it required mutating the caller's dataframe beforehand.
    merged['ignored'] = ~merged.index.isin(scored_index)
    return merged


def _get_list_of_dicts(pred_df, results_columns, explanations_cols=None):
    """Convert each row of ``pred_df`` into a JSON-ready dict.

    Rows whose "ignored" value is truthy become
    ``{'ignored': True, 'ignoreReason': "IGNORED_BY_MODEL"}``. Other rows keep
    only ``results_columns``. When ``explanations_cols`` is given, they also get
    an "explanations" dict. Its keys are the explanation column names with
    EXPLANATION_COL_PREFIX removed, and NaN values are left out.
    """
    explanations = None
    if explanations_cols is not None:
        explanations = pred_df[explanations_cols].to_dict(orient="records")
    result_rows = pred_df[results_columns].to_dict(orient="records")

    output = []
    for row_idx, row in enumerate(result_rows):
        if row['ignored']:
            output.append({'ignored': True, 'ignoreReason': "IGNORED_BY_MODEL"})
            continue
        if explanations:
            row["explanations"] = {
                col.replace(EXPLANATION_COL_PREFIX, "", 1): value
                for col, value in explanations[row_idx].items()
                if not np.isnan(value)
            }
        output.append(row)
    return output


def _build_dataframe(predictor, data, advanced_options):
    """Build the dataframe to score from the records of an API request.

    Column dtypes are derived from the request schema and the model's
    per-feature preprocessing params. For partitioned models, the partition
    columns present in those dtypes are forced to 'str'. The dataframe is then
    passed through add_missing_columns (presumably adding the per-feature
    columns absent from the request — confirm).

    :param predictor: saved-model predictor whose params drive dtype resolution
    :param data: request dict with "schema" (DSS schema) and "records"
    :param advanced_options: debug switches ("dumpInputRecords", "dumpInputDataFrame")
    :return: the dataframe to pass to ``predictor.predict``
    """
    params = predictor.params
    per_feature = params.preprocessing_params[doctor_constants.PER_FEATURE]
    dtypes = utils.ml_dtypes_from_dss_schema(data["schema"], per_feature,
                                             prediction_type=params.core_params.get(doctor_constants.PREDICTION_TYPE))

    # "or" fallbacks also tolerate an explicit null in the model params
    partitioned_model = params.core_params.get("partitionedModel") or {}
    if partitioned_model.get("enabled", False):
        partition_cols = partitioned_model.get("dimensionNames") or []
        if partition_cols:
            logging.info("Scoring partitioned model with partition columns: %s", partition_cols)
            logging.info("Forcing their dtype to be 'str'")
            for partition_col in partition_cols:
                if partition_col in dtypes:
                    dtypes[partition_col] = "str"

    if advanced_options.get("dumpInputRecords", False):
        logging.info("Input dtypes: %s", dtypes)

    records_df = dataframe_from_dict_with_dtypes(data["records"], dtypes)
    records_df = add_missing_columns(records_df, dtypes, per_feature)

    logging.info("Done preparing missing records")
    logging.info("Done preparing input DF")

    if advanced_options.get("dumpInputDataFrame", False):
        logging.info("Input dataframe dump:\n%s", records_df)
        # Only dump the first row when there is one: a request with no records
        # would otherwise make the row lookup raise.
        if records_df.shape[0] > 0:
            for col in records_df.columns:
                logging.info("R0[%s] = %s", col, records_df[col].iloc[0])
        logging.info("Input dataframe dtypes:\n%s", records_df.dtypes)

    return records_df

def _has_colum_starting_with(df, prefix):
    """Tell whether at least one column name of ``df`` starts with ``prefix``."""
    return any(column.startswith(prefix) for column in df.columns)

# Request format expected by handle_predict:
# {
#   "records": {colname: [values], ...},
#   "schema": DSS schema (preparation output schema),
# }

def handle_predict(predictor, request):
    """Score the records of one API request and build the JSON-ready response.

    :param predictor: saved-model predictor built by ``serve``
    :param request: decoded request. It must contain "schema" and "records". It may
        contain "pyPredictionAdvancedOptions" (debug switches) and "explanations"
        (keys "enabled", "method", "nExplanations", "nMonteCarloSteps").
    :return: dict with "execTimeUS", "explanationsTimeUS" when explanations were
        computed, and the per-record results under a key that depends on the model
        (e.g. "regression", "classification", "clustering"). When the predictor
        returns no rows (outside of time series forecasting), only the results key
        is returned, with every record marked as ignored.
    :raises Exception: if the request has no schema
    """
    ret = {}

    core_params = predictor.params.core_params
    model_type = core_params.get("taskType")
    prediction_type = core_params.get(doctor_constants.PREDICTION_TYPE)

    # "or {}" also covers an explicit JSON null for the optional request sections
    advanced_options = request.get("pyPredictionAdvancedOptions") or {}

    if advanced_options.get("dumpInputRecords", False):
        logging.info("Input records %s", request["records"])

    if "schema" not in request:
        raise Exception("Schema not specified")

    # build the dataframe to predict
    records_df = _build_dataframe(predictor, request, advanced_options)
    nb_records = records_df.shape[0]

    predictor._set_debug_options(advanced_options)
    # perf_counter is monotonic: durations cannot be skewed by wall-clock adjustments
    before = time.perf_counter()
    if prediction_type == doctor_constants.TIMESERIES_FORECAST:
        pred_df = predictor.predict(records_df)
    else:
        pred_df = predictor.predict(records_df, with_proba_percentile=True, with_conditional_outputs=True)
    ret["execTimeUS"] = int(1000000 * (time.perf_counter() - before))

    explanations_cols = None
    explanation_params = request.get("explanations") or {}
    if explanation_params.get("enabled"):
        before = time.perf_counter()
        # Re-running the full prediction to get the explanations to leverage the normalization of the data prior
        # to passing it to the model.
        # Besides, prediction time is negligible compared with time spent on computing explanations.
        pred_df = predictor.predict(records_df,
                                    with_proba_percentile=True,
                                    with_conditional_outputs=True,
                                    with_explanations=True,
                                    explanation_method=explanation_params.get("method"),
                                    n_explanations=explanation_params.get("nExplanations"),
                                    n_explanations_mc_steps=explanation_params.get("nMonteCarloSteps"))
        ret["explanationsTimeUS"] = int(1000000 * (time.perf_counter() - before))
        explanations_cols = [c for c in pred_df.columns if c.startswith(EXPLANATION_COL_PREFIX)]

    kind = _get_output_kind(model_type, prediction_type)

    logging.info("Done predicting, shape=%s", pred_df.shape)
    if pred_df.shape[0] == 0 and prediction_type != doctor_constants.TIMESERIES_FORECAST:
        logging.info("Empty dataframe post processing")
        if kind is None:
            # Unknown model type: no per-record results, as in the non-empty case below
            return ret
        # A fresh dict per record, so that the entries never alias each other
        return {kind: [{'ignored': True, 'ignoreReason': "IGNORED_BY_MODEL"} for _ in range(nb_records)]}

    if model_type == "PREDICTION":
        if prediction_type == doctor_constants.TIMESERIES_FORECAST:
            ret[kind] = pred_to_dict_timeseries(pred_df, predictor.params)
        elif prediction_type in doctor_constants.CAUSAL_PREDICTION_TYPES:
            if core_params["enable_multi_treatment"] and len(core_params["treatment_values"]) > 2:
                treatments = core_params["treatment_values"]
                if predictor.params.preprocessing_params["drop_missing_treatment_values"]:
                    treatments = [t for t in treatments if t != ""]
                control = core_params["control_value"]
                ret[kind] = pred_to_dict_causal_multi_val_treatment(pred_df, nb_records, treatments, control,
                                                                     predictor.with_propensity)
            else:
                ret[kind] = pred_to_dict_causal_binary_treatment(pred_df, nb_records, predictor.with_propensity)
        else:
            with_pred_intervals = kind == 'regression' and _has_colum_starting_with(pred_df, "prediction_interval_")
            ret[kind] = pred_to_dict(pred_df, nb_records, explanations_cols=explanations_cols,
                                     with_pred_intervals=with_pred_intervals)
        if kind == "classification":
            _add_classification_probas(predictor, pred_df, prediction_type, ret[kind])
    elif model_type == "CLUSTERING":
        ret[kind] = pred_df.to_dict(orient="records")

    return ret


def _get_output_kind(model_type, prediction_type):
    """Return the response key for the per-record results, or None for an unknown model type."""
    if model_type == "PREDICTION":
        if prediction_type == 'REGRESSION':
            return 'regression'
        if prediction_type in {'BINARY_CLASSIFICATION', 'MULTICLASS', "DEEP_HUB_IMAGE_CLASSIFICATION"}:
            return 'classification'
        if prediction_type == "DEEP_HUB_IMAGE_OBJECT_DETECTION":
            return "objectDetection"
        if prediction_type == doctor_constants.TIMESERIES_FORECAST:
            return "timeseriesForecasting"
        if prediction_type in doctor_constants.CAUSAL_PREDICTION_TYPES:
            return "causalPrediction"
        # TODO: should the code fail for an unknown prediction type?
        # Lazy %s formatting: the type may be None, which "+" concatenation would reject
        logging.info("unknown prediction type: %s", prediction_type)
        return 'prediction'
    if model_type == "CLUSTERING":
        return "clustering"
    logging.info("unknown model type: %s", model_type)
    return None


def _add_classification_probas(predictor, pred_df, prediction_type, entries):
    """Attach class probabilities (plus binary-only percentile/conditionals) to result entries, in place."""
    if not _has_colum_starting_with(pred_df, "proba_"):
        return
    # Loop-invariant lookups, fetched once rather than once per record
    classes = predictor.get_classes()
    is_binary = prediction_type == "BINARY_CLASSIFICATION"
    conditional_names = predictor.get_conditional_output_names() if is_binary else []
    # pred_df rows are matched to entries by index label. This relies on the labels
    # being positions in ``entries``, which was built on range(nb_records).
    for record, idx in zip(pred_df.to_dict(orient='records'), pred_df.index):
        entry = entries[idx]
        entry["probas"] = {c: record["proba_%s" % c] for c in classes}
        if is_binary:
            entry["probaPercentile"] = record.get("proba_percentile", None)
            if len(conditional_names) > 0:
                entry["conditionals"] = {co: record[co] for co in conditional_names}


# socket-based connection to backend
def serve(port, secret, server_cert=None):
    """Serve predictions for one DSS-managed saved model over a JavaLink connection.

    The first JSON message read from the link gives the model folder and options
    ("modelFolder", "outputExplanations", "computePropensity"). Once the predictor
    is built, {"ok": True} is sent back. Each following JSON request is answered
    with ``handle_predict`` until a null request is read. An empty string then
    marks the end of the stream.

    On any failure, the error is logged. An empty string is sent, followed by a
    JSON-friendly error description. The link is always closed.
    """
    link = JavaLink(port, secret, server_cert=server_cert)
    link.connect()

    try:
        init_command = link.read_json()
        model_folder = init_command.get('modelFolder')
        wants_explanations = init_command.get("outputExplanations")
        conditional_outputs = _load_conditional_outputs(model_folder)
        # If we are here we know it is a DSS_MANAGED saved model type. An MLflow model would be served by dataiku.apinode.predict.mlflowpyfuncserver
        predictor = build_predictor_for_saved_model(build_noop_folder_context(model_folder), "DSS_MANAGED", conditional_outputs)
        if isinstance(predictor, CausalPredictor):
            predictor.with_propensity = init_command.get("computePropensity", False)
        if wants_explanations:
            predictor.ready_explainer()
        logging.info("Predictor ready")
        link.send_json({"ok": True})

        # A null request signals that the backend has no more work
        request = link.read_json()
        while request is not None:
            link.send_json(handle_predict(predictor, request))
            request = link.read_json()

        logging.info("Work done")
        link.send_string('')  # end of stream
    except BaseException:
        # As broad as a bare except: every failure is reported to the backend
        logging.exception("Prediction user code failed")
        link.send_string('')  # send null to mark failure
        link.send_json(get_json_friendly_error())
    finally:
        link.close()


def _load_conditional_outputs(model_folder):
    """Load conditional_outputs.json from the model folder, falling back to [] (logged) on any error."""
    try:
        return load_from_filepath(osp.join(model_folder, "conditional_outputs.json"))
    except Exception as err:
        logging.exception("Can't load conditional outputs: %s", err)
        return []


if __name__ == "__main__":
    watch_stdin()
    port, secret, server_cert = parse_javalink_args()
    serve(port, secret, server_cert=server_cert)
  
