import dataiku
import dash
from dash import dcc
import dash_bootstrap_components as dbc
from dash import html, callback_context
from dash.dependencies import Input, Output, State, MATCH, ALL
from dash import dash_table
import pandas as pd
import numpy as np
import dataikuapi
import generalized_linear_models
sys.modules['generalized_linear_models'] = generalized_linear_models

app.config.external_stylesheets = [dbc.themes.BOOTSTRAP]
use_api = dataiku.get_custom_variables()["use_api"]
api_node_url = dataiku.get_custom_variables()["api_node_url"]

if use_api == "True":
    client = dataikuapi.APINodeClient(api_node_url, "claim_risk")
else:
    claim_frequency = dataiku.Model("aHJZVrBQ")
    claim_frequency_predictor = claim_frequency.get_predictor()
    claim_severity = dataiku.Model("j1Mpq3TM")
    claim_severity_predictor = claim_severity.get_predictor()
    pure_premium = dataiku.Model("PVKl9xt0")
    pure_premium_predictor = pure_premium.get_predictor()

def process_data(data):
    data = pd.DataFrame.from_dict(data={k: [data[k]] for k in data})
    data['VehAge'] = data['VehAge'].clip(upper=20)
    data['DrivAge'] = data['DrivAge'].clip(lower=20, upper=90)
    data['VehPower'] = data['VehPower'].clip(upper=9).astype(str)
    data['VehAgeBin'] = pd.cut(data['VehAge'], bins=[0, 1, 10, 100], right=False, labels=['0 : 1', '1 : 10', '10 : 100'])
    data['DrivAgeBin'] = pd.cut(data['DrivAge'],  bins=[20, 21, 26, 31, 41, 51, 71, 100], right=False,
                                labels=['20 : 21', '21 : 26', '26 : 31', '31 : 41', '41 : 51', '51 : 71', '71 : 100'])
    data['LogDensity'] = np.log10(data['Density'])
    data['LogDensityBin'] = pd.cut(data['LogDensity'], right=False, bins=list(range(6)), labels=[str(i) for i in range(5)])
    data['BonusMalus'] = data['BonusMalus'].clip(upper=150)
    data['LogBonusMalus'] = np.log10(data['BonusMalus'])
    BonusMalusBins = list(range(50, 160, 10))
    data['BonusMalusBin'] = pd.cut(data['BonusMalus'], right=False, bins=BonusMalusBins, labels=[str(i) for i in BonusMalusBins[:-1]])
    return data

def make_prediction(predictor, data):
    prediction = predictor.predict(data)
    return prediction

def slider(min_value, max_value, step, marks, default, slider_id, title, description):
    return dbc.Row([dbc.Col([html.H6(title, style={'margin-bottom': '0em'}), 
                             html.P(description, style={'color': '#8f8f8f', 'font-size': '12px'})], md=6), 
                    dbc.Col(dcc.Slider(min_value,
                     max_value,
                     step,
                     marks=marks,
                     value=default,
                     id=slider_id,
    tooltip={"placement": "bottom", "always_visible": True}), md=6)], style={'margin-bottom': '0em'})    

def dropdown(options, default, dropdown_id, title, description):
    return dbc.Row([dbc.Col([html.H6(title, style={'margin-bottom': '0em'}), 
                             html.P(description, style={'color': '#8f8f8f', 'font-size': '12px'})], md=6),
                    dbc.Col(dcc.Dropdown(id=dropdown_id,
                       options=[{'label': value, 
                                 'value': value}
                         for value in options],
                       value=default), md=6, style={'font-size': '12px'})], 
                   style={'margin-bottom': '0em'})

vehicle_power_slider = slider(4, 15, 1,
                                   marks={'4': 4,
                     '8': 8,
                     '12': 12,
                     '15': 15},
               default=4,
               slider_id='vehicle-power',
              title='パワー',
              description='車両のパワー、４〜１５。１５が最大')

vehicle_age_slider = slider(0, 20, 1,
                                 marks={'0': 0,
                     '5': 5,
                     '10': 10,
                     '15': 15,
                     '20': 20},
               default=0,
               slider_id='vehicle-age',
              title='年数',
           description='車両の年数、２０年が最大値')

driver_age_slider = slider(18, 99, 1,
                               marks={'18': 18,
                     '30': 30,
                     '60': 60,
                     '90': 90},
               default=42,
               slider_id='driver-age',
              title='年齢',
              description='ドライバーの年齢')

bonus_malus_slider = slider(50, 150, 1,
               marks={'50': 50,
                     '100': 100,
                     '150': 150},
               default=50,
               slider_id='bonus-malus',
               title='ディスカウント',
               description='過去に請求に応じたディスカウント、50に近いことが望ましい。最大値は150')

vehicle_brand_dropdown = dropdown(
                    ['B' + str(brand) for brand in [*range(1, 7), *range(10, 15)]],
                    default='B1',
                    dropdown_id='vehicle-brand',
                    title='車種',
                    description='車両の車種')

vehicle_gas_dropdown = dropdown(
                        ['Regular', 'Diesel'],
                    default='Regular',
                    dropdown_id='vehicle-gas',
                    title='燃料',
                    description='燃料の種類')


density_slider = slider(0, 27000, 1,
               marks={'100': 100,
                     '1000': 1000,
                     '10000': 10000},
               default=100,
               slider_id='density',
              title='密度',
              description='住んでいる地域の人口密度。１キロ平方メートルあたりの居住者')

region_dropdown = dropdown(
                        ['Alsace',
                         'Aquitaine',
                         'Auvergne',
                         'Basse-Normandie',
                         'Bourgogne',
                         'Bretagne',
                         'Centre',
                         'Champagne-Ardenne',
                         'Corse',
                         'Franche-Comte',
                         'Haute-Normandie',
                         'Ile-de-France',
                         'Languedoc-Roussillon',
                         'Limousin',
                         'Midi-Pyrenees',
                         'Nord-Pas-de-Calais',
                         'Pays-de-la-Loire',
                         'Picardie',
                         'Poitou-Charentes',
                         'Provence-Alpes-Cote-D\'Azur',
                         'Rhone-Alpes'],
                    default='Centre',
                    dropdown_id='region',
                    title='地域',
                    description='顧客の地域、2016年以前に定義されていた地域を選択')

def render_prediction(name, value, description):
    return dbc.Row([
            dbc.Col([html.H6(name + ": ", style={'text-align': 'left', 'margin-bottom': '0em'}),
                    html.P(description, style={'color': '#8f8f8f', 'font-size': '12px'})], md=6),
            dbc.Col(html.H6(str(round(value, 2)), 
                    style={'margin-bottom': '2em', 'text-align': 'left', 'color': '#3b99fc'}), md=6)])

@app.callback(Output('result', 'children'),
              Input('vehicle-power', 'value'),
              Input('vehicle-age', 'value'),
              Input('driver-age', 'value'),
              Input('bonus-malus', 'value'),
              Input('vehicle-brand', 'value'),
              Input('vehicle-gas', 'value'),
              Input('density', 'value'),
              Input('region', 'value'))
def render_results(vehicle_power, vehicle_age,
                  driver_age, bonus_malus,
                  vehicle_brand, vehicle_gas,
                  density, region):
    if use_api == "True":
        record_to_predict = {
            "Exposure": "1",
            "ClaimNb": "1",
            "VehPower": str(vehicle_power),
            "VehAge": str(vehicle_age),
            "DrivAge": str(driver_age),
            "BonusMalus": str(bonus_malus),
            "VehBrand": str(vehicle_brand),
            "VehGas": str(vehicle_gas),
            "Density": str(density),
            "Region": str(region)
        }
        claim_frequency = client.predict_record("claim_frequency", record_to_predict)['result']['prediction']
        claim_severity = client.predict_record("claim_severity", record_to_predict)['result']['prediction']
        pure_premium = client.predict_record("pure_premium", record_to_predict)['result']['prediction']
    else:
        record_to_predict = {
            "Exposure": 1,
            "ClaimNb": 1,
            "VehPower": vehicle_power,
            "VehAge": vehicle_age,
            "DrivAge": driver_age,
            "BonusMalus": bonus_malus,
            "VehBrand": vehicle_brand,
            "VehGas": vehicle_gas,
            "Density": density,
            "Region": region
        }
        processed_data = process_data(record_to_predict)
        claim_frequency = make_prediction(claim_frequency_predictor, processed_data)['prediction'].iloc[0]
        claim_severity = make_prediction(claim_severity_predictor, processed_data)['prediction'].iloc[0]
        pure_premium = make_prediction(pure_premium_predictor, processed_data)['prediction'].iloc[0]
    output = dbc.Col([
        render_prediction('請求件数', claim_frequency, '予測される請求件数'),
        render_prediction('請求額', claim_severity, '請求額の予測（ユーロ）'),
        render_prediction('複合純保険料', claim_frequency*claim_severity, '複合モデルを使った請求額の予測（ユーロ）'),
        render_prediction('Tweedie純保険料', pure_premium, 'Tweedieモデルを使った請求額の予測（ユーロ）')])
    return output

# build your Dash app
app.layout = dbc.Container(
    dbc.Col([
        html.H3("請求モデルアプリケーション", style={'margin-bottom': '0em'}),
        html.P("このインタラクティブな画面上で、ユーザーはモデルの入力パラメータを変更し、その結果を確認することができます。"+
               "コンパウンドモデルとTweedieモデルを並べて比較しています。 " + 
               "モデルがAPIノードにプッシュされている場合はモデルをAPIコールで呼び出すことができ、フローにデプロイされたモデルから直接呼び出すことも可能です。",
              style={'color': '#8f8f8f', 'font-size': '14px', 'margin-top': '0em'}),
        html.Hr(),
        dbc.Row([
            dbc.Col([vehicle_power_slider,
                    vehicle_age_slider,
                    driver_age_slider,
                    bonus_malus_slider,
                    vehicle_brand_dropdown,
                    vehicle_gas_dropdown,
                    density_slider,
                    region_dropdown], md=6,
                   style={'border-right': '1px solid',
                         'border-right-color': '#e3e4e4'}),
            dbc.Col([dbc.Container(id='result'),
                   html.Div([html.H6("説明", style={'font-size': '15px'}),
                             html.P("この4つの予測では、モデルの特徴は左側で設定したもので、1年間のリスクを表現するためにExposureを1としたものです。 " + 
                                    "請求件数は、請求頻度モデルを用いて予測され、出力は、与えられたパラメータに対する1年間の請求件数の予想値です。 "+
                                    "請求額は、請求の重大度モデルを使用しており、所定の保険金支払額の予測を表しています。 " +
                                    "複合純保険料予測は、上記の2つの予測である請求頻度と請求額の積です。これは、1年間におけるお客様のリスクの合計を貨幣価値で表したものです。"+
                                    "Tweedie純保険料予測は、複合純保険料と同じ量を計算しますが、頻度と重大性を別々にではなく、直接モデル化することを目指した純保険料Tweedieモデルを使用します。", 
                                    style={'font-size': '12px'})], 
                            style={'display': 'block', 
                                   'background': '#e6eef2',
                                    'padding': '20px',
                                    'border-radius': '5px',
                                    'color': '#31708f'})])
        ])
    ]), style={'font-family': 'Helvetica Neue'}
)

