from flask import Blueprint, jsonify, request
from commons.python.business_solutions_api.dataiku_api import dataiku_api
from commons.python.business_solutions_api.dataset_api import fetch_filtered_dataiku_dataset
import dataiku
import numpy as np
import json
from dataikuapi.dss.job import DataikuException
import pandas as pd

from project.src.functions import (
    get_current_project_and_variables,
    compute_frontend_parameters,
    get_project_and_variables,
    get_user,
)

# Blueprint grouping every endpoint of this module under the "/api" prefix.
fetch_api = Blueprint("fetch_api", __name__, url_prefix="/api")

# Resolved once, at import time, via get_current_project_and_variables().
# `main_project_key` is the fallback used when a request names no project.
project, variables = get_current_project_and_variables()
main_project_key = project.project_key

def get_run_info(run):
    """Summarise a scenario run as a JSON-friendly dict.

    Args:
        run: object exposing ``id`` and ``running`` attributes.

    Returns:
        ``{"id": run.id, "running": run.running}``.
    """
    run_id, is_running = run.id, run.running
    return dict(id=run_id, running=is_running)

########## First api call example ############
@fetch_api.route("/hello", methods=["GET"])
def hello():
    """Smoke-test endpoint that always answers ``{"key": "hello"}``."""
    payload = {"key": "hello"}
    return jsonify(payload)


def get_dataset(dataset_name):
    """Return a ``dataiku.Dataset`` handle for *dataset_name*.

    No project key is passed, so the name is presumably resolved against the
    project this webapp runs in. Confirm this if cross-project reads are ever
    needed.
    """
    return dataiku.Dataset(name=dataset_name)


@fetch_api.route("/name", methods=["GET"])
def get_dataset_name():
    """Echo the name of the dataset given in the ``dataset`` query argument."""
    requested_name = request.args.get("dataset")
    return get_dataset(dataset_name=requested_name).name


@fetch_api.route("/dataset_schema", methods=["GET"])
def get_dataset_schema():
    """Return the schema of the fixed ``results_per_granularity`` dataset.

    No query arguments are read.
    """
    results_dataset = get_dataset(dataset_name="results_per_granularity")
    # NOTE(review): confirm dataiku.Dataset exposes get_schema(); the internal
    # API documents read_schema() for this purpose.
    schema = results_dataset.get_schema()
    return schema


@fetch_api.route("/dataset_data", methods=["GET"])
def get_dataset_data():
    """Return every row of the dataset named by the ``dataset`` query argument.

    Query args:
        dataset: name of the dataset to read.
        project: accepted for compatibility but ignored. The dataset is
            resolved by ``get_dataset``, which takes no project key.

    Returns:
        ``{"rows": [...]}`` with one dict per row. Missing values (NaN/NaT)
        are sent as ``None`` (JSON ``null``). NaN is not valid JSON and would
        make the frontend's JSON parser reject the whole response.
    """
    dataset_name = request.args.get("dataset")
    df = get_dataset(dataset_name).get_dataframe()
    # Cast to object first so `where` can store None in numeric/datetime columns.
    df = df.astype(object).where(df.notna(), None)
    return {"rows": df.to_dict(orient="records")}


@fetch_api.route("/dataset_columns", methods=["GET"])
def get_dataset_columns():
    """Describe the columns of the dataset named by the ``dataset`` query arg.

    The ``project`` query argument is ignored. The dataset is resolved by
    ``get_dataset``, which takes no project key.

    Returns:
        ``{"columns": [...]}`` with one table-column descriptor per dataframe
        column. The name, label and field are all set to the column name, and
        each column is centred and sortable.
    """
    dataset_name = request.args.get("dataset")
    # NOTE(review): this loads the whole dataframe only to read its column
    # names. Reading the schema alone would be cheaper.
    df = get_dataset(dataset_name).get_dataframe()
    columns = [
        {"name": col, "align": "center", "label": col, "field": col, "sortable": True}
        for col in df.columns
    ]
    return {"columns": columns}


@fetch_api.route("/summary", methods=["GET"])
def get_summary_data():
    """Return the KPI comparison table for one forecast granularity.

    Query args:
        store: value of ``forecast_granularity`` to keep. (``project`` is
            accepted but ignored.)

    How it works:
        1. Reads ``results_per_global_strategy_prepared``.
        2. Keeps the requested granularity.
        3. Transposes so each row is a KPI and each column a global strategy.
        4. Adds the rounded percentage uplift of "Markdown optimization" over
           "No Markdown" and over "Constant Markdown".

    Returns:
        ``{"rows": [...], "columns": [...]}`` for a frontend table.
    """
    dataset_name = "results_per_global_strategy_prepared"
    forecast_granularity = request.args.get("store")
    df = get_dataset(dataset_name).get_dataframe()
    df = (
        df[df["forecast_granularity"] == forecast_granularity]
        .drop(columns=["forecast_granularity"])
        .drop_duplicates()
    )
    df = df.set_index("global_strategy").transpose().reset_index()
    rename_dict_transpose = {
        "index": "KPI",
        "custom": "Markdown optimization",
        "zero": "No Markdown",
        "static": "Constant Markdown",
    }
    df = df.rename(columns=rename_dict_transpose)

    # uplift column name -> baseline strategy column it is measured against
    uplift_columns = {
        "'No Markdown' Uplift": "No Markdown",
        "'Constant Markdown' Uplift": "Constant Markdown",
    }
    for uplift_col, baseline_col in uplift_columns.items():
        uplift = (df["Markdown optimization"].div(df[baseline_col]) - 1) * 100
        # A zero baseline gives +/-inf. fillna does not catch inf, and inf would
        # serialize as invalid JSON ("Infinity"), so treat it like NaN -> 0.
        uplift = uplift.replace([np.inf, -np.inf], np.nan).fillna(0)
        df[uplift_col] = np.round(uplift)

    # Flip the sign for "lost" KPIs, presumably because lower is better there,
    # so that a positive uplift reads as an improvement.
    lost_kpis = df["KPI"].isin(["lost_sales", "lost_profit"])
    df.loc[lost_kpis, list(uplift_columns)] *= -1

    rows = df.to_dict(orient="records")
    columns = [
        {
            "name": col,
            "required": True,
            "align": "center",
            "label": col,
            "field": col,
            "sortable": True,
        }
        for col in df.columns
    ]

    return {"rows": rows, "columns": columns}


@fetch_api.route("/history", methods=["GET"])
def get_history_data():
    """Build line-chart series: historical KPIs plus per-strategy forecast points.

    Query args:
        store: value of ``forecast_granularity`` to keep.

    Returns:
        ``{"history": [...]}``, where each series is a dict with ``name``,
        ``data``, ``type`` and ``color``.

        * Historical series come from ``historic_data``, one per KPI column.
          Their data are ``(iso_date, value)`` pairs sorted by date.
        * Forecast series come from ``aggregated_results``, one per KPI column
          and global strategy. Each is plotted at the forecast period start
          date. If that dataset has no rows for the granularity, only the
          historical series are returned.
    """
    forecast_granularity = request.args.get("store")
    formated_columns = []

    # --- Historical series ---------------------------------------------------
    history_df = get_dataset("historic_data").get_dataframe()
    history_df = history_df[
        history_df["forecast_granularity"] == forecast_granularity
    ].drop(columns=["forecast_granularity"])
    history = history_df.to_dict(orient="list")
    # Pull out the date axis before looping, so the result no longer depends
    # on "reference_period_start_date" being the first column of the dataset.
    dates = [date.isoformat() for date in history.pop("reference_period_start_date")]
    for col, values in history.items():
        data = sorted(zip(dates, values), key=lambda pair: pair[0])
        formated_columns.append(
            {"name": col, "data": data, "type": "line", "color": "#4578FC"}
        )

    # --- Forecast points: one series per KPI and global strategy -------------
    results_df = get_dataset("aggregated_results").get_dataframe()
    results_df = results_df[
        results_df["forecast_granularity"] == forecast_granularity
    ].drop(columns=["forecast_granularity"])
    if results_df.empty:
        # No forecast for this granularity. Return the history alone rather
        # than failing on `.iloc[0]` below.
        return {"history": formated_columns}

    strategies_colors = {"zero": "#66418A", "custom": "#39561B", "static": "#F35B05"}
    strategies_names = {
        "custom": "Markdown optimization",
        "zero": "No Markdown",
        "static": "Constant Markdown",
    }
    forecast_date = (
        results_df["forecast_period_start_date"].drop_duplicates().iloc[0].isoformat()
    )

    for strategy, color in strategies_colors.items():
        strategy_df = results_df[results_df["global_strategy"] == strategy].drop(
            columns=["global_strategy"]
        )
        for col, values in strategy_df.to_dict(orient="list").items():
            if col == "forecast_period_start_date":
                continue
            formated_columns.append(
                {
                    "name": f"{col} {strategies_names[strategy]}",
                    # One [date, value...] point at the forecast start.
                    # Presumably there is one row per strategy; TODO confirm.
                    "data": [[forecast_date] + values],
                    "type": "line",
                    "color": color,
                }
            )
    return {"history": formated_columns}


@fetch_api.route("/elasticity", methods=["POST"])
def get_elasticity_data():
    """Return per-product uplift curves as a function of discount.

    JSON body:
        products: list of product ids to include. If missing or null, no
            product is selected.
        profitVolumeToggle: "Profit" selects ``profit_uplift_zero``. Any other
            value selects ``volume_uplift_zero``.

    Each row of ``price_elasticity_windows_prepared`` stores ``discount`` and
    the selected uplift column as JSON-encoded lists. These are zipped into
    ``(discount, uplift)`` pairs sorted by discount.

    Returns:
        ``{"elasticity": [{"name": product_id, "data": [...], "type": "line"}, ...]}``
    """
    # Parse the request first so malformed requests fail before the dataset load.
    req_json = request.get_json(force=True)
    products = req_json.get("products")
    profitVolumeToggle = req_json["profitVolumeToggle"]

    if products is None:
        print("No filter 1")
        products = []

    uplift_column = (
        "profit_uplift_zero" if profitVolumeToggle == "Profit" else "volume_uplift_zero"
    )

    df = get_dataset("price_elasticity_windows_prepared").get_dataframe()
    df = df[df["product_id"].isin(products)]

    formated_data = []
    for row in df.to_dict(orient="records"):
        discounts = json.loads(row["discount"])
        uplifts = json.loads(row[uplift_column])
        formated_data.append(
            {
                "name": row["product_id"],
                "data": sorted(zip(discounts, uplifts), key=lambda pair: pair[0]),
                "type": "line",
            }
        )

    return {"elasticity": formated_data}


@fetch_api.route("/categories", methods=["GET"])
def get_categories():
    """Return the distinct product category values at both levels.

    Reads ``product_categories``. No query arguments are used.

    Returns:
        ``{"level_1": [...], "level_2": [...]}``. ``level_1`` holds the distinct
        ``universe`` values and ``level_2`` the distinct ``type`` values, each
        in first-seen order. The two lists are computed independently.
    """
    df = get_dataset("product_categories").get_dataframe()

    level_1 = df["universe"].drop_duplicates().to_list()
    level_2 = df["type"].drop_duplicates().to_list()

    return {"level_1": level_1, "level_2": level_2}

@fetch_api.route("/products", methods=["GET"])
def get_products():
    """Return the distinct ``product_id`` values as ``{"products": [...]}``.

    Values are returned in first-seen order.
    """
    dataset_name = "sales_x_features_joined_by_product_id"  # TODO update name with final dataset
    df = get_dataset(dataset_name).get_dataframe()

    return {"products": df["product_id"].drop_duplicates().to_list()}

@fetch_api.route("/updaterow", methods=["POST"])
def updaterow():
    """Prepend edited rows to a user's ``input.json`` in the managed folder.

    JSON body:
        rows: list of row dicts. The first row's ``override_value`` arrives as
            a percentage and is stored as a fraction (``int(value) / 100``).
        username: login whose ``/<username>/input.json`` file is updated.

    Returns:
        "Writeback Done", or ``("Invalid username", 400)`` when ``username``
        is not a safe single path component.
    """
    req_json = request.get_json(force=True)
    input_data = req_json["rows"]
    # NOTE(review): only the first row is rescaled. Confirm the frontend never
    # sends more than one row to this endpoint.
    if input_data:
        input_data[0]["override_value"] = int(input_data[0]["override_value"]) / 100

    user_login = str(req_json["username"])
    # The login comes from the request body and is used as a path component,
    # so refuse values that could escape the per-user directory.
    # NOTE(review): /build_all_update derives the user from the request headers
    # via get_user(). Consider doing the same here instead of trusting the body.
    if user_login in ("", ".", "..") or "/" in user_login or "\\" in user_login:
        return "Invalid username", 400

    folder_input = dataiku.Folder("CuQeLXpI", project_key=main_project_key)
    path = f"/{user_login}/input.json"
    try:
        existing_rows = folder_input.read_json(path)
    except Exception:
        # Best effort: if the file cannot be read (e.g. it does not exist yet),
        # start from an empty list. Unlike a bare `except:`, this does not
        # swallow KeyboardInterrupt or SystemExit.
        existing_rows = []
        print("No existing Json")
    # New rows go first, ahead of everything saved previously.
    folder_input.write_json(path, input_data + existing_rows)
    return "Writeback Done"


@fetch_api.route("/categories_aggregated", methods=["GET"])
def get_categories_agg():
    """Map each ``universe`` to its JSON-decoded ``type`` value.

    Reads ``product_categories_by_universe_prepared``. If a universe appears
    on several rows, the last row's value wins.
    """
    records = (
        get_dataset("product_categories_by_universe_prepared")
        .get_dataframe()
        .to_dict(orient="records")
    )
    return {record["universe"]: json.loads(record["type"]) for record in records}


@fetch_api.route("/build_all_update", methods=["POST"])
def build_all_update():
    """Save the submitted rows for the current user, then trigger the rebuild scenario.

    Steps:
        1. Take the user login from the request headers via ``get_user`` and
           stamp it on every row as ``username``.
        2. Prepend the rows to ``/<user>/input.json`` in the managed folder.
        3. Run the ``MKD11BUILD_ALL_FROM_UPDATE`` scenario.

    Returns:
        The id of the run returned by ``wait_for_scenario_run()``, or "Error"
        if triggering raised a ``DataikuException``.
    """
    req_json = request.get_json(force=True)
    input_data = req_json["rows"]
    user_login = get_user(dict(request.headers))
    for row in input_data:
        row["username"] = user_login

    folder_input = dataiku.Folder("CuQeLXpI", project_key=main_project_key)
    path = f"/{user_login}/input.json"
    try:
        existing_rows = folder_input.read_json(path)
    except Exception:
        # Best effort: if the file cannot be read (e.g. it does not exist yet),
        # start from an empty list. Unlike a bare `except:`, this does not
        # swallow KeyboardInterrupt or SystemExit.
        existing_rows = []
        print("No existing Json")
    # New rows go first, ahead of everything saved previously.
    folder_input.write_json(path, input_data + existing_rows)

    project = dataiku_api.get_project()
    scenario = project.get_scenario("MKD11BUILD_ALL_FROM_UPDATE")
    try:
        trigger_fire = scenario.run()
        scenario_run = trigger_fire.wait_for_scenario_run()
    except DataikuException as error:
        print(error)
        # Sentinel understood by /scenario_status.
        return "Error"
    return scenario_run.id


@fetch_api.route("/scenario_status", methods=["GET"])
def get_scenario_status():
    """Report the state of one scenario run.

    Query args:
        scenario_id: run id, or the literal "Error" sentinel.
        scenario_name: id of the scenario that owns the run.

    Returns:
        "Not Running" for the "Error" sentinel, "Running" while the run is in
        progress, and otherwise the run's ``outcome``.
    """
    scenario_id = request.args.get("scenario_id")
    if scenario_id == "Error":
        return "Not Running"
    owning_scenario = dataiku_api.get_project().get_scenario(
        request.args.get("scenario_name")
    )
    run = owning_scenario.get_run(scenario_id)
    return "Running" if run.running else run.outcome


@fetch_api.route("/load_frontend_parameters", methods=["GET"])
def load_frontend_parameters():
    """Return ``compute_frontend_parameters`` for every distinct product.

    The products are the distinct ``product_id`` values of
    ``product_strategy_synced``.
    """
    print("Loading frontend parameters ...")

    synced_df = get_dataset("product_strategy_synced").get_dataframe()
    unique_products = synced_df["product_id"].drop_duplicates().tolist()
    return compute_frontend_parameters(unique_products)


@fetch_api.route("/wiki", methods=["GET"])
def get_wiki_article():
    """Return the body of the project wiki article given by ``article_id``."""
    article_id = request.args.get("article_id")
    wiki = dataiku_api.get_project().get_wiki()
    article_data = wiki.get_article(article_id).get_data()
    return article_data.get_body()

@fetch_api.route("/project_variable", methods=["GET"])
def get_project_variable():
    """Return one "standard" variable of a project.

    Query args:
        variable_name: key to look up in the project's standard variables.
        project: project key. If missing or empty, falls back to
            ``main_project_key``.
    """
    variable_name = request.args.get("variable_name")
    project_key = request.args.get("project") or main_project_key
    standard_variables = dataiku_api.get_project(project_key=project_key).get_variables()["standard"]
    return standard_variables[variable_name]

@fetch_api.route("/demand_forecast_time_scale", methods=["GET"])
def get_demand_forecast_time_scale():
    """Describe the demand-forecast time scale for the frontend.

    Reads ``time_step`` and ``forecast_horizon`` from the standard variables
    of the project whose key is stored in this project's
    ``demand_forecast_project_key_app`` variable. Reads ``min_date`` and
    ``max_date`` from the row labelled 0 of ``dates_boundaries``.

    Returns:
        dict with ``forecast_time_granularity``, ``max_date``,
        ``forecast_horizon`` and ``min_date``. Dates are formatted as
        ``YYYY-MM-DD``.
    """
    demand_project_key = variables["standard"]["demand_forecast_project_key_app"]
    _, variables_demand = get_project_and_variables(demand_project_key)
    demand_standard = variables_demand["standard"]
    time_step = demand_standard["time_step"]
    horizon = demand_standard["forecast_horizon"]

    boundaries = get_dataset("dates_boundaries").get_dataframe()
    date_format = "%Y-%m-%d"
    first_date = boundaries["min_date"][0].strftime(date_format)
    last_date = boundaries["max_date"][0].strftime(date_format)

    return {
        "forecast_time_granularity": time_step,
        "max_date": last_date,
        "forecast_horizon": horizon,
        "min_date": first_date,
    }


@fetch_api.route("/optimization_results", methods=["GET"])
def get_optimization_results():
    """Return optimization results limited to products in the writeback table.

    Inner-joins ``display_table_by_product_id_prepared`` (left) with
    ``writeback_webapp_synced_windows`` (right) on ``product_id``. Overlapping
    columns from the right side get an ``_old`` suffix. Every missing value is
    then replaced by 0.

    Returns:
        ``{"rows": [...]}`` with one dict per joined row.
    """
    writeback_df = get_dataset("writeback_webapp_synced_windows").get_dataframe()
    results_df = get_dataset("display_table_by_product_id_prepared").get_dataframe()

    joined = results_df.merge(
        writeback_df, on="product_id", how="inner", suffixes=("", "_old")
    )
    return {"rows": joined.fillna(0).to_dict(orient="records")}

@fetch_api.route("/last_run", methods=["GET"])
def get_last_optimization_run():
    """Describe the latest run of the ``MKD11BUILD_ALL_FROM_UPDATE`` scenario.

    Returns:
        ``{"id": ..., "running": ...}`` for the most recent run, or
        ``{"id": "", "running": False}`` when ``get_last_runs`` returns no run.
    """
    project = dataiku_api.get_project()
    scenario = project.get_scenario("MKD11BUILD_ALL_FROM_UPDATE")
    last_runs = scenario.get_last_runs(limit=1)
    if len(last_runs) > 0:
        return get_run_info(last_runs[0])
    # This dict must be returned: a Flask view that returns None raises an
    # error instead of sending a response.
    return {"id": "", "running": False}