import dataiku
import pandas as pd, numpy as np
from dataiku import pandasutils as pdu
import tensorflow as tf
import tensorflow_probability as tfp
import arviz as az
import io
import pickle
import joblib
from dataiku import insights

from meridian import constants
from meridian.data import load
from meridian.data import test_utils
from meridian.model import model
from meridian.model import spec
from meridian.model import prior_distribution
from meridian.analysis import optimizer
from meridian.analysis import analyzer
from meridian.analysis import visualizer
from meridian.analysis import summarizer
from meridian.analysis import formatter

import time


def run_what_if_exploration(ranges, constraint_type, target_roi=None, new_budget=None, gtol=0.001,
                            managed_folder_id="ajsbTNVY"):
    """Run a Meridian budget-optimization "what-if" scenario and tabulate it.

    Loads a pickled Meridian model from a Dataiku managed folder, runs
    ``BudgetOptimizer.optimize`` over the evaluation window stored in the
    project variables, publishes the optimization summary as the
    ``optim_scenario`` HTML insight, and returns a per-channel comparison of
    the custom, current and reference scenarios.

    Args:
        ranges: Iterable of ``(min_pct, max_pct)`` pairs expressed in percent,
            presumably one per channel in the model's channel order (confirm
            with caller). Only the negative part of ``min_pct`` and the
            positive part of ``max_pct`` are used; see ``_spend_constraints``.
        constraint_type: ``"roi"`` discards ``new_budget`` (the budget is not
            fixed when ``target_roi`` is given); ``"budget"`` discards
            ``target_roi`` and optimizes the fixed ``new_budget``; any other
            value optimizes a fixed budget equal to the project variable
            ``total_current_budget``.
        target_roi: Target ROI, only used when ``constraint_type == "roi"``.
        new_budget: Total budget, only used when ``constraint_type == "budget"``.
        gtol: Convergence tolerance forwarded to ``optimize``.
        managed_folder_id: Dataiku managed folder containing ``model.pkl``.

    Returns:
        pandas.DataFrame with columns ``Media Channels``, ``Custom Budget``,
        ``Custom Sales``, ``Current Budget``, ``Current Sales``,
        ``Reference Budget`` and ``Reference Sales``.
    """
    project = dataiku.api_client().get_default_project()
    std_vars = project.get_variables()["standard"]

    # Named ``mmm`` so it does not shadow the ``meridian.model.model`` module.
    mmm = _load_pickled_model(managed_folder_id)

    if constraint_type == "roi":
        new_budget = None
    elif constraint_type == "budget":
        target_roi = None
    else:
        target_roi = None
        new_budget = std_vars["total_current_budget"]

    print(ranges)
    low, high = _spend_constraints(ranges)

    budget_optimizer = optimizer.BudgetOptimizer(mmm)

    print('start optim')
    start_time = time.perf_counter()
    optimization_results = budget_optimizer.optimize(
        selected_times=[std_vars["min_date_eval_app"], std_vars["max_date_eval_app"]],
        budget=new_budget,
        # Without an ROI target the total budget is held fixed.
        fixed_budget=target_roi is None,
        target_roi=target_roi,
        spend_constraint_lower=low,
        spend_constraint_upper=high,
        batch_size=500,
        gtol=gtol,
    )
    print("--- %s seconds ---" % (time.perf_counter() - start_time))

    print('start insight')
    start_time = time.perf_counter()
    # NOTE(review): _gen_optimization_summary is a private Meridian method and
    # may change between Meridian releases.
    optim_file = optimization_results._gen_optimization_summary()
    insights.save_data(
        "optim_scenario",
        payload=optim_file,
        content_type="text/html",
        label=None,
        project_key=None,
        encoding=None,
    )
    print("--- %s seconds ---" % (time.perf_counter() - start_time))

    return _build_comparison_frame(optimization_results.optimized_data, std_vars)


def _load_pickled_model(managed_folder_id):
    """Read and unpickle ``model.pkl`` from the given Dataiku managed folder."""
    folder = dataiku.Folder(managed_folder_id)
    # The context manager closes the stream; no explicit close() is needed.
    with folder.get_download_stream("model.pkl") as stream:
        payload = stream.read()
    # SECURITY: pickle.loads can execute arbitrary code. Only point this at a
    # managed folder whose contents are trusted.
    return pickle.loads(payload)


def _spend_constraints(ranges):
    """Convert percent ``(min, max)`` pairs into fractional lower/upper bounds.

    For each pair, the lower bound is ``-min/100`` when ``min`` is negative
    (otherwise 0) and the upper bound is ``max/100`` when ``max`` is positive
    (otherwise 0). Both are returned as non-negative magnitudes, which is the
    form passed to Meridian's ``spend_constraint_lower``/``_upper``.
    """
    lower, upper = [], []
    for pair in ranges:
        lo_frac = pair[0] / 100
        hi_frac = pair[1] / 100
        lower.append(-lo_frac if lo_frac < 0 else 0)
        upper.append(hi_frac if hi_frac > 0 else 0)
    return lower, upper


def _build_comparison_frame(optimized_data, std_vars):
    """Tabulate optimized spend per channel next to current/reference figures."""
    # Name the index explicitly so the output column name does not depend on
    # how pandas names an index built from an xarray coordinate (the old
    # reset_index + rename('index') approach silently relied on that).
    channels = pd.Index(optimized_data.channel.values, name="Media Channels")
    df = pd.DataFrame(
        {"Custom Budget": [int(v) for v in optimized_data.spend.values]},
        index=channels,
    )
    df["Custom Sales"] = int(optimized_data.total_incremental_outcome)

    df["Current Budget"] = std_vars["current_spend"]
    df["Current Sales"] = std_vars["current_sales"]

    df["Reference Budget"] = std_vars["opt_spend"]
    df["Reference Sales"] = std_vars["opt_target"]

    return df.reset_index()