from flask import Blueprint, jsonify, request
from .python import graphs_utils as gu
from .python import constants
import numpy as np
import pandas as pd
import json
import datetime
import dataiku
import dataikuapi
import math
from commons.python.fetch.config_bs import ConfigBs, EnvMode
import logging
import threading
import time

mode = ConfigBs.mode()

fetch_api = Blueprint("fetch_api", __name__)

# Local development runs against a fixed project; otherwise the project key is
# taken from the webapp's custom variables.
if mode == EnvMode.LOCAL.value:
    project_key = "SOL_BATCH_PERF_OPTIM"
else:
    project_key = dataiku.get_custom_variables()["projectKey"]

# One API client serves every call in this module.
client = dataiku.api_client()
project = client.get_project(project_key)
# Project variables; the "standard" section holds the attribute configuration
# read and written by the endpoints below.
variables = project.get_variables()

def to_json(o):
    """``default=`` hook for ``json.dumps`` covering values json cannot encode.

    NumPy integer/bool scalars (which do not inherit from the builtin ``int`` /
    ``bool``), non-float64 NumPy floats (e.g. ``float32``, which does not inherit
    from ``float``) and NumPy arrays are converted to native Python values.
    Any other object is serialized through its ``__dict__`` (presumably the
    chart/dashboard objects nested in the responses).
    """
    if isinstance(o, np.integer):
        return int(o)
    if isinstance(o, np.floating):
        return float(o)
    if isinstance(o, np.bool_):
        return bool(o)
    if isinstance(o, np.ndarray):
        return o.tolist()
    return o.__dict__

# Attribute-level batch data used to list the selectable values in the
# scenario-parameter endpoints; unlike the datasets below, it is read
# synchronously at import time.
batch_data_by_attribute = (
    dataiku.Dataset('batch_data_by_attribute', project_key=project_key)
    .get_dataframe()
)

## Load all datasets asynchronously
batch_data = None


def load_batch_dataset():
    """Read ``batch_data_prepared`` into the ``batch_data`` global.

    The start-time column is parsed as datetimes and floored to the day.
    """
    global batch_data
    source = dataiku.Dataset('batch_data_prepared', project_key=project_key)
    batch_data = source.get_dataframe()
    parsed_start = pd.to_datetime(batch_data[constants.START_TIME])
    batch_data[constants.START_TIME] = parsed_start.dt.floor('D')


load_batch_dataset_thread = threading.Thread(target=load_batch_dataset)
load_batch_dataset_thread.start()


sensor_data = None


def load_sensor_dataset():
    """Load ``sensor_data_analysis`` into the ``sensor_data`` global.

    The timestamp column is parsed as datetimes. The function is run once in a
    background thread at import time and again synchronously by
    ``setSensorScenarioParams``, while other requests may be reading
    ``sensor_data``. The frame is therefore fully prepared before it replaces
    the global, so readers never see an unconverted timestamp column.
    """
    global sensor_data
    frame = dataiku.Dataset('sensor_data_analysis', project_key=project_key).get_dataframe()
    frame[constants.TIMESTAMP] = pd.to_datetime(frame[constants.TIMESTAMP])
    sensor_data = frame


load_sensor_dataset_thread = threading.Thread(target=load_sensor_dataset)
load_sensor_dataset_thread.start()

# NOTE: sensor_data is initialised next to its own loader, whose thread is
# already running at this point; it must not be reset here or freshly loaded
# data could be discarded.
pred_proba = None
sensor_data_for_interpretation = None
batch_and_sensor_for_prediction = None


def load_prediction_dataset():
    """Load the three prediction-page datasets into their module globals.

    Run once in a background thread at import time and again synchronously by
    ``setPredictionScenarioParams``. All three datasets are read before any
    global is replaced, so the globals are swapped back to back. This avoids
    one swap per slow read, and a failed read leaves the previous
    generation intact.
    """
    global pred_proba
    global sensor_data_for_interpretation
    global batch_and_sensor_for_prediction
    scored = dataiku.Dataset('next_batch_scored', project_key=project_key).get_dataframe()
    interpretation = dataiku.Dataset('sensor_data_for_interpretation', project_key=project_key).get_dataframe()
    for_prediction = dataiku.Dataset('data_for_prediction', project_key=project_key).get_dataframe()
    pred_proba, sensor_data_for_interpretation, batch_and_sensor_for_prediction = (
        scored,
        interpretation,
        for_prediction,
    )


load_prediction_dataset_thread = threading.Thread(target=load_prediction_dataset)
load_prediction_dataset_thread.start()

########################### BATCH PAGE ########################### 

def getDashboardBatchPage(batch_data, selected_filters):
    """Build the batch-page dashboard from the (already filtered) batch rows.

    Charts: outcome counts by the main attribute, outcome counts over time,
    and mean batch duration (minutes) by the main attribute.

    ``batch_data`` is not modified. ``selected_filters`` is unused and kept
    for interface compatibility. Returns the dashboard object's ``__dict__``.
    """
    colors = {"success" : '#27a300', "failure" : '#E45545'}
    # Chart by the first configured attribute that is not the equipment id,
    # falling back to the equipment id itself.
    main_attribute = next(
        (attribute for attribute in variables["standard"]["attribute_names"]
         if attribute != constants.EQUIPMENT_ID),
        constants.EQUIPMENT_ID,
    )

    # Private copy: the caller's frame (typically a filtered slice) stays
    # untouched, and adding columns to a slice would raise SettingWithCopyWarning.
    data = batch_data.copy()
    data['count'] = 1

    main_attribute_count_series = gu.generateStackedSeries(data, constants.FAILURE_STR, main_attribute, 'count', 'sum', type='bar', sort=True, colors=colors)
    count_chart = gu.XYChart('Success rate by %s'%main_attribute, 'category', 'value', main_attribute_count_series)

    # Start times become strings (category axis) and rows are ordered by them.
    data[constants.START_TIME] = data[constants.START_TIME].astype(str)
    data.sort_values(by=constants.START_TIME, inplace=True)
    date_count_series = gu.generateStackedSeries(data, constants.FAILURE_STR, constants.START_TIME, 'count', 'sum', type='bar', colors=colors)
    count_over_time_chart = gu.XYChart('Success rate over time', 'category', 'value', date_count_series, dataZoom=True)

    duration_series = gu.generateStackedSeries(data, constants.FAILURE_STR, main_attribute, 'batch_duration_minutes', 'mean', type='bar', stack=False, sort=True, colors=colors)
    duration_chart = gu.XYChart('Average batch duration by %s'%main_attribute, 'category', 'value', duration_series, graph_subtitle='(in minutes)')
    dashboard = gu.Dashboard('', [count_chart, count_over_time_chart, duration_chart])
    return dashboard.__dict__

def getFiltersBatchPage(batch_data):
    """Build the batch-page filter definitions from the unfiltered batch rows.

    Returns two filters:
    - an equipment-id multi-select with the sorted distinct ids;
    - a start-time date range formatted ``YYYY/MM/DD``, whose upper bound is
      one day past the latest start time.

    An empty frame yields the same two filters with empty values.
    """
    if batch_data.shape[0] == 0:
        return [{"name" : constants.EQUIPMENT_ID, "values" : {}, "type" : "string"}, {"name" : constants.START_TIME, "values" : {}, "type" : "date"}]

    equipment_ids = sorted(batch_data[constants.EQUIPMENT_ID].unique())

    # Series.min/max skip NaT; builtin min/max would not, and a leading NaT
    # would then reach strftime and fail.
    start_times = batch_data[constants.START_TIME]
    min_date = start_times.min()
    max_date = start_times.max() + datetime.timedelta(days=1)
    return [
        {"name" : constants.EQUIPMENT_ID, "values" : equipment_ids, "type" : "string"},
        {"name" : constants.START_TIME, "values" : {"from": min_date.strftime('%Y/%m/%d'), "to" : max_date.strftime('%Y/%m/%d')}, "type" : "date"},
    ]

def _batch_metrics(batch_rows):
    """Summary metrics shown above the batch dashboard for the given rows."""
    nb_batches = len(batch_rows)
    nb_equipments = batch_rows["equipment_id"].nunique()
    span = batch_rows["start_time"].max() - batch_rows["start_time"].min()
    # An empty selection has no time span (NaT).
    nb_days = 0 if pd.isna(span) else span.days
    # When every selected batch starts on the same day nb_days is 0; count it
    # as one day so the rate stays defined. No equipment -> no rate.
    denominator = max(nb_days, 1) * nb_equipments
    batches_per_equipment_day = math.ceil(nb_batches / denominator) if denominator else 0
    return [
        {"name": 'Monitoring Duration (days)', "value": nb_days},
        {"name": 'Number of equipment', "value": nb_equipments},
        {"name": 'Number of batches', "value": nb_batches},
        {"name": 'Number of batches/equipment & day', "value": batches_per_equipment_day},
    ]


@fetch_api.route("/api/getBatchPage", methods=["POST"])
def getBatchPage():
    """Serve the batch page: filter definitions, dashboard charts and metrics.

    The JSON body maps filter names to selections:
    - ``constants.EQUIPMENT_ID``: a list of equipment ids;
    - ``constants.START_TIME``: ``{"from": ..., "to": ...}``.

    A ``None`` selection (or a null body) leaves that dimension unfiltered.
    """
    load_batch_dataset_thread.join()

    selected_filters = request.get_json(force=True) or {}
    filters = getFiltersBatchPage(batch_data)  # compute filters before filtering

    # Copy so nothing downstream can modify the shared, cached dataset.
    batch_data_copy = batch_data.copy()
    for filter_id, selection in selected_filters.items():
        if selection is None:
            continue
        if filter_id == constants.EQUIPMENT_ID:
            batch_data_copy = batch_data_copy[batch_data_copy[filter_id].isin(selection)]
        elif filter_id == constants.START_TIME:
            batch_data_copy = batch_data_copy[batch_data_copy[filter_id].between(selection["from"], selection["to"])]

    metrics = _batch_metrics(batch_data_copy)
    dashboard = getDashboardBatchPage(batch_data_copy, selected_filters)
    return json.dumps({"filters" : filters, "dashboard" : dashboard, "metrics" : metrics}, default=to_json)


########################### SENSOR PAGE ########################### 
 
def getDashboardSensorPage(sensor_data, selected_filters):
    """Build one line chart per sensor column of the (optionally filtered) sensor rows.

    ``selected_filters`` keys handled here:
    - ``constants.BATCH_ID``: a list of batch ids;
    - ``constants.TIMESTAMP``: ``{"from": ..., "to": ...}``;
    - ``"groupby_failure"``: group series by outcome when ``True``, otherwise
      by batch id.

    Empty selections leave the data unfiltered. Batch-model, outcome,
    timestamp and attribute columns are not charted. Returns the dashboard
    object's ``__dict__``.
    """
    ignored_columns = [
        *constants.BATCH_BASE_DATA_MODEL,
        constants.FAILURE_STR,
        "timestamp",
        "timestamp_resampled",
        *variables["standard"]["attribute_names"],
    ]
    colors = {"success" : '#27a300', "failure" : '#E45545'}
    shades_of_colors = ["#aec7e8", "#ffbb78", "#98df8a", "#ff9896", "#c5b0d5", "#c49c94", "#f7b6d2", "#c7c7c7", "#dbdb8d", "#9edae5"]

    groupby_column = constants.FAILURE_STR
    # Boolean masks and column subsets produce new frames, so the shared
    # dataset is never modified and no up-front copy is needed.
    filtered = sensor_data
    for filter_id, selection in selected_filters.items():
        if filter_id == constants.BATCH_ID and selection:
            filtered = filtered[filtered[filter_id].isin(selection)]
        elif filter_id == constants.TIMESTAMP and selection:
            filtered = filtered[filtered[filter_id].between(selection["from"], selection["to"])]
        elif filter_id == "groupby_failure":
            # `== True` (not truthiness) is deliberate: only a true-like flag groups by outcome.
            groupby_column = constants.FAILURE_STR if selection == True else constants.BATCH_ID

    charts = []
    for sensor_column in filtered.columns:
        if sensor_column in ignored_columns:
            continue
        line_series = gu.generateLineSeries(filtered[["timestamp_resampled", sensor_column, groupby_column]], "timestamp_resampled", sensor_column, groupby_column, nb_bin=50, colors=colors)
        chart = gu.XYChart(sensor_column.replace("_sensor_value_avg", ""), 'value', 'value', line_series, graph_subtitle='Average value of the sensor', dataZoom=True)
        chart.option["color"] = shades_of_colors
        charts.append(chart)

    dashboard = gu.Dashboard('', charts)
    return dashboard.__dict__

def getFiltersSensorPage(sensor_data):
    """Build the sensor-page filter definitions from the unfiltered sensor rows.

    Returns three filters:
    - the outcome/batch group-by toggle (default ``True``);
    - a batch-id multi-select with the sorted distinct ids;
    - a timestamp date range formatted ``YYYY/MM/DD``, whose upper bound is
      one day past the latest timestamp.

    An empty frame yields the same filters with empty values.
    """
    # Same toggle (and label) whether or not there is data.
    groupby_filter = {"name" : "groupby_failure", "values" : True, "type" : "boolean", "label" : "Group by Outcome / by Batch"}
    if sensor_data.shape[0] == 0:
        return [groupby_filter, {"name" : constants.BATCH_ID, "values" : {}, "type" : "string"}, {"name" : constants.TIMESTAMP, "values" : {}, "type" : "date"}]

    batch_ids = sorted(sensor_data[constants.BATCH_ID].unique())

    # Series.min/max skip NaT; builtin min/max would not.
    timestamps = sensor_data[constants.TIMESTAMP]
    min_date = timestamps.min()
    max_date = timestamps.max() + datetime.timedelta(days=1)
    return [
        groupby_filter,
        {"name" : constants.BATCH_ID, "values" : batch_ids, "type" : "string"},
        {"name" : constants.TIMESTAMP, "values" : {"from": min_date.strftime('%Y/%m/%d'), "to" : max_date.strftime('%Y/%m/%d')}, "type" : "date"},
    ]

@fetch_api.route("/api/getSensorPage", methods=["POST"])
def getSensorPage():
    """Serve the sensor page: filter definitions plus the filtered dashboard.

    Waits for the initial sensor load, then builds the filters from the full
    dataset and the charts from the user's selection in the JSON body.
    """
    load_sensor_dataset_thread.join()

    user_selection = request.get_json(force=True)
    payload = {
        "filters": getFiltersSensorPage(sensor_data),
        "dashboard": getDashboardSensorPage(sensor_data, user_selection),
    }
    return json.dumps(payload, default=to_json)

@fetch_api.route("/api/getSensorScenarioParams")
def getSensorScenarioParams():
    """List each configured attribute with its selectable values and current filter.

    Values are the distinct non-missing entries of ``batch_data_by_attribute``
    for that attribute, ordered by their string form. The current selection
    comes from the ``attribute_filters`` project variable at the same position.
    """
    standard = variables["standard"]
    params = []
    for position, var_name in enumerate(standard["attribute_names"]):
        distinct = sorted(batch_data_by_attribute[var_name].unique(), key=str)
        params.append({
            "name" : var_name,
            "values" : [value for value in distinct if not pd.isna(value)],
            "type" : "string",
            "selection" : standard["attribute_filters"][position],
        })
    return json.dumps(params, default=to_json)

@fetch_api.route("/api/setSensorScenarioParams", methods=["POST"])
def setSensorScenarioParams():
    """Save the selected attribute filters, run BUILDDASHBOARD and reload sensor data.

    The JSON body maps every configured attribute name to its selection.
    Responds ``{"success": true}`` (200), or ``{"success": false}`` (500) when
    the scenario run or the reload raises.
    """
    selected_attributes = request.get_json(force=True)
    variables["standard"]["attribute_filters"] = [selected_attributes[x] for x in variables["standard"]["attribute_names"]]
    project.set_variables(variables)

    scenario = project.get_scenario("BUILDDASHBOARD")
    try:
        scenario.run_and_wait()
        load_sensor_dataset()
    except Exception:
        # Request boundary: report failure to the client, keep the traceback in the logs.
        logging.getLogger(__name__).exception("BUILDDASHBOARD scenario or sensor data reload failed")
        return json.dumps({'success': False}), 500, {'Content-Type': 'application/json'}

    return json.dumps({'success': True}), 200, {'Content-Type': 'application/json'}



########################### PREDICTION PAGE ########################### 

@fetch_api.route("/api/getPredictionScenarioParams")
def getPredictionScenarioParams():
    """List each configured attribute with its selectable values and current prediction choice.

    Values are the distinct non-missing entries of ``batch_data_by_attribute``
    for that attribute, ordered by their string form. The current selection
    comes from the ``attribute_prediction`` project variable at the same position.
    """
    standard = variables["standard"]
    params = []
    for position, var_name in enumerate(standard["attribute_names"]):
        distinct = sorted(batch_data_by_attribute[var_name].unique(), key=str)
        params.append({
            "name" : var_name,
            "values" : [value for value in distinct if not pd.isna(value)],
            "type" : "string",
            "selection" : standard["attribute_prediction"][position],
        })
    return json.dumps(params, default=to_json)

@fetch_api.route("/api/setPredictionScenarioParams", methods=["POST"])
def setPredictionScenarioParams():
    """Save the selected prediction attributes, run PREDICT and reload prediction data.

    The JSON body maps every configured attribute name to its selection.
    Responds ``{"success": true}`` (200), or ``{"success": false}`` (500) when
    the scenario run or the reload raises.
    """
    selected_attributes = request.get_json(force=True)
    variables["standard"]["attribute_prediction"] = [selected_attributes[x] for x in variables["standard"]["attribute_names"]]
    project.set_variables(variables)

    scenario = project.get_scenario("PREDICT")
    try:
        scenario.run_and_wait()
        load_prediction_dataset()
    except Exception:
        # Request boundary: report failure to the client, keep the traceback in the logs.
        logging.getLogger(__name__).exception("PREDICT scenario or prediction data reload failed")
        return json.dumps({'success': False}), 500, {'Content-Type': 'application/json'}

    return json.dumps({'success': True}), 200, {'Content-Type': 'application/json'}

def normalize(x):
    """Map a Shapley value to an integer display score in ``[0, 100]``.

    ``x`` is clamped to ``[0, 1]`` first. Above 1, ``sin`` would start
    decreasing (a larger contribution would score lower); below 0, the
    fractional power is undefined. The curve ``sin(x * pi / 2) ** 0.7`` lies
    above the identity on ``[0, 1]``, so small contributions map to larger
    scores. ``x == 1`` maps exactly to 100.
    """
    x = min(max(x, 0.0), 1.0)
    return int(min(100, math.pow(math.sin(x * math.pi / 2), 0.7) * 100))


# Substrings of a sensor feature name, checked in order (first match wins),
# mapped to a human-readable description and the aggregation label.
_SENSOR_AGGREGATIONS = (
    ("avg_min", "minimum value", "min"),
    ("avg_max", "maximum value", "max"),
    ("avg_avg", "average value", "avg"),
    ("avg_stddev", "variation", "stddev"),
    ("avg_count", "number of emitted values", "count"),
)


def _describe_aggregation(col_name):
    """Return ``(name_explanation, aggregation)`` for a sensor feature, or ``("", "")``."""
    for marker, explanation, aggregation in _SENSOR_AGGREGATIONS:
        if marker in col_name:
            return explanation, aggregation
    return "", ""


def _compare_to_success(real_value, avg_success_value):
    """Label how the batch value compares with the successful-batch average.

    Both arguments are ``:.2f``-formatted strings (``"nan"`` for NaN).
    ``avg_success_value`` is ``""`` when there is no successful-batch reference.
    Returns:
    - ``"None"`` when there is no usable reference;
    - ``"missing"`` when the batch value is NaN;
    - otherwise ``"higher"`` or ``"lower"``.
    """
    if avg_success_value in ("nan", ""):
        return "None"
    if real_value == "nan":
        return "missing"
    return "higher" if float(real_value) > float(avg_success_value) else "lower"


def _next_failure_rate(success_rows, failure_rows):
    """Failure share in [0, 1] from the per-outcome ``count`` rows.

    Presumably one row per outcome; only the first row of each is read.
    """
    if success_rows.shape[0] == 0:
        return 1
    if failure_rows.shape[0] == 0:
        return 0
    failures = failure_rows["count"].iloc[0]
    return failures / (success_rows["count"].iloc[0] + failures)


@fetch_api.route("/api/getPrediction")
def getPrediction():
    """Serve the next-batch prediction and its explanations.

    Returns JSON with:
    - ``failure_risk``: ``proba_1`` of the first scored row, as an integer
      percentage;
    - ``explanations``: the failure rate as a fraction, the last batch's
      attributes, and one entry per sensor feature with a positive Shapley
      value, ordered by decreasing value.
    """
    load_prediction_dataset_thread.join()

    # Snapshot the globals so a concurrent reload cannot swap them mid-request.
    scored = pred_proba
    interpretation = sensor_data_for_interpretation
    last_batch = batch_and_sensor_for_prediction

    # prediction (first scored row)
    failure_risk = int(round(scored["proba_1"].iloc[0] * 100))

    # explanations: JSON object feature -> Shapley value; values are converted
    # with float() since they may be serialized as strings.
    raw_shapley = json.loads(scored['explanations'].iloc[0])
    shapley_values = sorted(
        ((name, float(value)) for name, value in raw_shapley.items()),
        key=lambda item: item[1],
        reverse=True,
    )
    success_rows = interpretation[interpretation["failure_last"] == 0]
    failure_rows = interpretation[interpretation["failure_last"] == 1]
    next_failure_rate = _next_failure_rate(success_rows, failure_rows)

    ## get last batch attributes
    batch_attributes = [
        {"name" :"equipment_id", "value" : last_batch["equipment_id"].iloc[0]},
        {"name" :"batch_id", "value" : last_batch["batch_id_first"].iloc[0]},
    ]
    batch_attributes.extend(
        {"name" : attr, "value" : last_batch["%s_first" % attr].iloc[0]}
        for attr in variables["standard"]["attribute_names"]
        if attr != "equipment_id"
    )

    explanations = {"next_failure_rate": next_failure_rate, "batch_params" : [], "sensor_data" : [], "batch_attributes": batch_attributes}

    has_success_reference = success_rows.shape[0] != 0
    for col_name, shapley_value in shapley_values:
        # Only sensor features with a positive Shapley value are explained.
        if shapley_value > 0 and "_sensor_value_" in col_name:
            real_value = f'{last_batch[col_name].iloc[0]:.2f}'
            avg_success_value = ""
            if has_success_reference:
                avg_success_value = f'{success_rows["%s_avg" % col_name].iloc[0]:.2f}'
            name_explanation, aggregation = _describe_aggregation(col_name)
            explanations["sensor_data"].append({
                "sensor_name" : col_name.split("_sensor_value")[0],
                "name_explanation": name_explanation,
                "order": _compare_to_success(real_value, avg_success_value),
                "real_value": real_value,
                "avg_success_value": avg_success_value,
                "shapley_value": normalize(shapley_value),
                "aggregation": aggregation,
            })

    return json.dumps({"explanations": explanations, "failure_risk" : failure_risk}, default=to_json)
