import time
import os
import re
import dataiku
import openai
import json
import logging
import functools
from datetime import datetime

from project_utils import YouSearchWrapper, BraveSearchWrapper, filter_urls, with_timeout, index_urls, clean_html

from openai import OpenAI
import tiktoken

import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc

LOG_ALL_ANSWERS = False # Whether all answers or only answers with user feedback should be logged
WEBAPP_NAME = "qanda_web" # Name of the app (logged when a conversation is flagged)
VERSION = "1.0" # Version of the app (logged when a conversation is flagged)

# Folder to log answers and positive/negative reactions
answers_folder = dataiku.Folder("JQJOdB0K")

LLM = "gpt-3.5-turbo"
NUM_SEARCH_RESULTS = 5
MAX_TOKENS_CONTEXT = 3500 # Token budget for the search-result extracts sent to the LLM
encoding = tiktoken.encoding_for_model(LLM)
SEARCH_ENGINE = "You" # Replace with "Brave" to use the Brave Search API

# API keys come from the user secrets. Search-engine keys are exported as
# environment variables (presumably read from there by the project_utils
# search wrappers -- TODO confirm); the OpenAI key is used to build the client.
openai_client = None
auth_info = dataiku.api_client().get_auth_info(with_secrets=True)
for secret in auth_info["secrets"]:
    if secret["key"] == "YDC_API_KEY":
        os.environ["YDC_API_KEY"] = secret["value"]
    elif secret["key"] == "BRAVE_API_KEY":
        os.environ["BRAVE_API_KEY"] = secret["value"]
    elif secret["key"] == "openai_key":
        openai_client = OpenAI(api_key=secret["value"])

ERROR_MESSAGE_MISSING_KEY = """
OpenAI key or You.com API key or Brave Search API key missing. Cf. this project's wiki

**Please note that the question answering web app is not live on Dataiku’s public project gallery but you can test it by downloading the project and providing the keys as user secrets**.
"""

# Only the wrapper of the engine selected by SEARCH_ENGINE is built, and only
# when its key is available; otherwise search_engine stays None.
search_engine = None
if SEARCH_ENGINE == "Brave" and "BRAVE_API_KEY" in os.environ:
    search_engine = BraveSearchWrapper(NUM_SEARCH_RESULTS)
elif SEARCH_ENGINE == "You" and "YDC_API_KEY" in os.environ:
    search_engine = YouSearchWrapper(NUM_SEARCH_RESULTS)

# The app is ready when keys_provided >= 3: 2 points for the OpenAI client and
# 1 point for the *selected* search engine. A key for the other engine does not
# count, since it would let the readiness check pass while search_engine is
# unusable.
keys_provided = (2 if openai_client is not None else 0) + (1 if search_engine is not None else 0)

# Markdown characters handled by escape_markdown
_MARKDOWN_SPECIAL_CHARS = "*`_~>[]()"
# A backslash followed by one of the special characters above
_ALREADY_ESCAPED = re.compile(r"\\([*`_~>\[\]()])")
# Maps every special character to its backslash-escaped form
_ESCAPE_TABLE = str.maketrans({char: "\\" + char for char in _MARKDOWN_SPECIAL_CHARS})


def escape_markdown(text):
    """
    Backslash-escape the Markdown characters * ` _ ~ > [ ] ( ) in ``text``.

    Characters that are already escaped are unescaped first, so that escaped
    input is not escaped a second time. Backslashes that do not precede one of
    these characters are left untouched.
    """
    unescaped = _ALREADY_ESCAPED.sub(r"\1", text)
    return unescaped.translate(_ESCAPE_TABLE)

# System prompt for the retrieval-augmented answer: the model must ground its
# answer in the numbered facts (REF1, REF2...) that get_answer_with_sources puts
# in the user prompt, keep those labels out of the answer text, and say
# "I don't know" when unsure.
SYSTEM_PROMPT_RAG = """You are a helpful assistant that answers questions based on facts retrieved from the Internet.
You justify your answers by referring to the relevant facts.
Don't include the references to these facts (e.g. "REF1", "REF2"...) directly in your answer.
Include only one of these facts except if several facts are needed to reach your conclusion.
If you are unsure, answer: "I don't know"."""

def format_sources(extracts):
    """
    Render each search result as a Markdown line ``[title](link): snippet``.

    ``extracts`` is an iterable of mappings with ``title``, ``link`` and
    ``snippet`` keys. Each field is passed through ``escape_markdown`` so that
    brackets, parentheses, underscores, etc. cannot break the rendering.
    Returns a list of strings, one per extract, in input order.
    """
    def as_markdown(extract):
        title, link, snippet = (escape_markdown(extract[key]) for key in ("title", "link", "snippet"))
        return f"[{title}]({link}): {snippet}"

    return [as_markdown(extract) for extract in extracts]

def get_answer_function(source_ids, max_sources=2):
    """
    Generate the function specs for the OpenAI "function calling" feature.

    The model is expected to answer by "calling" ``display_answer`` with a
    JSON object holding the answer, whether an answer was found, and the ids
    of the facts supporting it.

    :param source_ids: ids the model may cite (e.g. ``["REF1", "REF2"]``);
        they become the ``enum`` of the ``sources`` items.
    :param max_sources: maximum number of sources the model may cite. It sets
        both the schema's ``maxItems`` and the natural-language description,
        so the two instructions given to the model cannot disagree.
    :returns: a one-element list with the ``display_answer`` function spec.
    """
    return [
        {
            "name": "display_answer",
            "description": "Display the answer and the relevant sources",
            "parameters": {
                "type": "object",
                "properties": {
                    "answer_found": {
                        "type": "boolean",
                        "description": "Answer found. Whether an answer has been found given the facts provided.",
                    },
                    "sources": {
                        "type": "array",
                        "minItems": 0,
                        "maxItems": max_sources,
                        "items": {
                            "type": "string",
                            "enum": source_ids,
                        },
                        "description": (
                            "Sources supporting the answer. Sources are denoted by REF1, REF2... "
                            f"Mention at most {max_sources} sources. Do not include redundant sources"
                        ),
                    },
                    "answer": {
                        "type": "string",
                        "description": "Answer. Don't include the references to these facts (e.g. 'REF1', 'REF2'...) directly in your answer."
                    },
                },
                "required": ["answer", "sources", "answer_found"],
            },
        }
    ]

# Extracts the number of a cited source id such as "REF2"
_REF_NUMBER_PATTERN = re.compile(r"\D*(\d+)\D*")


def _truncate_context(context):
    """Return the longest prefix of ``context`` whose cumulative token count fits in MAX_TOKENS_CONTEXT."""
    kept = []
    num_tokens = 0
    for chunk in context:
        num_tokens += len(encoding.encode(chunk))
        if num_tokens > MAX_TOKENS_CONTEXT:
            break
        kept.append(chunk)
    return kept


def _cited_chunks(source_ids, context):
    """Map cited ids such as "REF2" to their context chunks, dropping invalid and duplicate ids."""
    chunks = []
    seen = set()
    for source_id in source_ids:
        match = _REF_NUMBER_PATTERN.search(str(source_id))
        if match is None:
            continue
        index = int(match.group(1)) - 1  # REF numbering starts at 1
        if 0 <= index < len(context) and index not in seen:
            seen.add(index)
            chunks.append(context[index])
    return chunks


def get_answer_with_sources(question, context):
    """
    Answer the question using the context information and providing sources.

    The context chunks are kept in order as long as their cumulative token
    count (measured with ``encoding``) does not exceed ``MAX_TOKENS_CONTEXT``.
    The kept chunks are labelled REF1, REF2... in the prompt. The LLM is forced
    to reply through the ``display_answer`` function so that its answer and
    cited references can be parsed as JSON.

    :param question: question asked by the user.
    :param context: sequence of context chunks (Markdown strings, e.g. built
        by ``format_sources``).
    :returns: ``(answer, answer_found)``. When at least one valid reference is
        cited, ``answer`` is followed by a "Sources:" section with the cited
        chunks (each at most once). When no chunk fits in the token budget
        (including an empty ``context``), the LLM is not called and
        ``("I don't know", False)`` is returned.
    """
    context = _truncate_context(context)
    if not context:
        # Nothing to ground an answer on; this also avoids sending an empty
        # list of allowed source ids in the function schema.
        return "I don't know", False

    formatted_context = "\n\n" + "\n\n".join(
        f"REF{k}. {chunk}" for k, chunk in enumerate(context, start=1)
    ) + "\n\n"
    user_prompt = f"Based on the following facts:{formatted_context}...Answer this question: {question}"

    messages = [
        {
            "role": "system",
            "content": SYSTEM_PROMPT_RAG
        },
        {
            "role": "system",
            "content": f"The current date and time are: {str(datetime.now())}"
        },
        {
            "role": "user",
            "content": user_prompt
        },
    ]

    response = openai_client.chat.completions.create(
        model=LLM,
        functions=get_answer_function([f"REF{k}" for k in range(1, len(context) + 1)]),
        function_call={"name": "display_answer"},
        messages=messages,
        temperature=0
    )

    # The forced function call returns its arguments as a JSON string
    arguments = json.loads(response.choices[0].message.function_call.arguments)
    answer = arguments["answer"]

    cited = _cited_chunks(arguments.get("sources") or [], context)
    if cited:
        answer += "\n\nSources:\n\n" + "\n\n".join(cited)

    return answer, arguments.get("answer_found", False)

def basic_web_research(question, queries=None, domain=None):
    """
    Answer the question with one or several web searches.

    Each query is sent to ``search_engine``; the question itself is used when
    ``queries`` is None or empty. When ``domain`` is non-empty, the searches
    are restricted to it with a ``site:`` operator. The formatted results are
    then used as context by ``get_answer_with_sources``.

    :returns: ``((answer, answer_found), urls)``, where ``urls`` lists the link
        of every search result in the order the results were returned.
    """
    domain_suffix = "" if domain is None or len(domain) == 0 else f" site:{domain}"

    if queries is not None and len(queries) > 0:
        # Concatenate the results of every query, in query order
        search_results = [
            result
            for query in queries
            for result in search_engine.results(query + domain_suffix)
        ]
    else:
        search_results = search_engine.results(question + domain_suffix)

    formatted_search_results = format_sources(search_results)
    logging.info("Web search: " + "\n".join(formatted_search_results))
    answer_and_flag = get_answer_with_sources(question, formatted_search_results)
    urls = [result["link"] for result in search_results]
    return answer_and_flag, urls

@functools.lru_cache()
def get_answer(question, domain=None):
    """
    Answer the question.

    The question itself is the single search query. Results are memoized by
    ``functools.lru_cache`` per call arguments. The answer text is returned
    whether or not the search snippets were sufficient; the outcome only
    changes what is logged.
    """
    research, urls = basic_web_research(question, queries=[question], domain=domain)
    answer, answer_found = research
    if not answer_found:
        logging.info("Answer not found only with search results:\n" + f"{answer} / {', '.join(urls)}")
        return answer
    logging.info("Answer based on the search results snippets:\n" + answer)
    return answer

# Layout

# Container holding the answer and the feedback buttons side by side
STYLE_ANSWER = {
    "margin-top": "20px",
    "align-items": "flex-start",
    "display": "flex",
    "height": "auto"
}

# Each feedback (smiley / frown) link
STYLE_BUTTON = {
    "width": "20px",
    "text-align": "center",
    "margin": "0px 5px",
}

# Feedback buttons container; hidden until an answer is displayed
STYLE_FEEDBACK = {
    "margin-left": "10px",
    "display": "none"
}


def _bootstrap_icon(class_name):
    """Wrap a Bootstrap icon element (``<i class=...>``) in a ``<span>``."""
    return html.Span(html.I(className=class_name))


# Filled variants show which reaction the user selected
ok_icon_fill = _bootstrap_icon("bi bi-emoji-smile-fill")
nok_icon_fill = _bootstrap_icon("bi bi-emoji-frown-fill")
ok_icon = _bootstrap_icon("bi bi-emoji-smile")
nok_icon = _bootstrap_icon("bi bi-emoji-frown")

send_icon = _bootstrap_icon("bi bi-send")

_query_input = dbc.Input(id='query', value='', type='text', minLength=0, placeholder="Ask a question")
_send_button = dbc.Button(send_icon, id='send-btn', title='Get an answer')
question_bar = dbc.InputGroup([_query_input, _send_button], style={"margin-top": "20px"})

_domain_input = dbc.Input(
    id='domain',
    value='',
    type='text',
    minLength=0,
    placeholder="Optional domain name to restrict the internet search (e.g. wikipedia.org)",
    style={"margin-top": "10px"},
)

_feedback_buttons = html.Div(
    [
        html.A(ok_icon, id="link_ok", href="#", style=STYLE_BUTTON),
        html.A(nok_icon, id="link_nok", href="#", style=STYLE_BUTTON),
    ],
    id="feedback_buttons",
    style=STYLE_FEEDBACK,
)

_answer_area = html.Div(
    [
        dbc.Spinner(
            dcc.Markdown(id='answer', link_target="_blank", style={"min-width": "100px"}),
            color="primary",
        ),
        _feedback_buttons,
    ],
    style=STYLE_ANSWER,
)

# NOTE(review): `app` is not defined in this file; presumably it is provided by
# the Dataiku webapp runtime -- confirm.
app.title = "Question answering"
app.config.external_stylesheets = [dbc.themes.ZEPHYR, dbc.icons.BOOTSTRAP]
app.layout = html.Div(
    [
        html.H4(
            "Question answering with Internet search",
            style={"margin-top": "20px", "text-align": "center"},
        ),
        question_bar,
        _domain_input,
        _answer_area,
        html.Div(id='debug'),
        # Client-side state shared by the callbacks. A feedback value of 2 means
        # "no feedback yet for the current question" (logged as 0).
        dcc.Store(id='feedback', data=2, storage_type='memory'),
        dcc.Store(id='question', storage_type='memory'),
        dcc.Store(id='question_id', storage_type='memory'),
    ],
    style={"margin": "auto", "text-align": "left", "max-width": "800px"},
)

# Callbacks

@app.callback(
    Output('answer', 'children'),
    Output('feedback_buttons', 'style'),
    Output('question', 'data'),
    Output('question_id', 'data'),
    Input('send-btn', 'n_clicks'),
    Input('query', 'n_submit'),
    Input('domain', 'n_submit'),
    State('query', 'value'),
    State('domain', 'value'),
)
def answer_question(n_clicks, n_submit, n_submit2, query, domain):
    """
    Display the answer.

    Triggered by the send button or by pressing Enter in the question or
    domain field. The trigger counters are unused; only the current values of
    the question and domain fields matter.

    :returns: a 4-tuple for the outputs above. It holds the Markdown answer
        with the elapsed time appended, the style of the feedback buttons, the
        question that was answered, and an id for this answer (0 when there is
        no real answer to rate).
    """
    if not query:
        # Empty question field: clear the answer and hide the feedback buttons
        return "", STYLE_FEEDBACK, query, 0
    start = time.time()
    if keys_provided < 3:
        # Configuration error: there is nothing to rate, so keep the feedback
        # buttons hidden. Showing them would let users rate the error message,
        # and that feedback would be logged under the shared id 0.
        return ERROR_MESSAGE_MISSING_KEY, STYLE_FEEDBACK, query, 0
    style = dict(STYLE_FEEDBACK, display="flex")
    answer = get_answer(query, domain=domain or "")
    elapsed = time.time() - start
    # The id names the log file of this answer (see log_answer)
    return f"{answer}\n\n{elapsed:.1f} seconds", style, query, hash(str(start) + query)

@app.callback(
    Output('debug', 'children'),
    Input('answer', 'children'),
    Input('link_ok', 'children'),
    Input('link_nok', 'children'),
    State('question', 'data'),
    State('domain', 'value'),
    State('question_id', 'data'),
    State('feedback', 'data'),
)
def log_answer(answer, ok, nok, query, domain, question_id, feedback):
    """
    Log the question and the answer.

    Fires whenever a new answer is displayed or the feedback icons change. A
    JSON record is written to ``answers_folder`` as ``/<question_id>.json`` in
    two cases: ``LOG_ALL_ANSWERS`` is set, or the user gave positive (1) or
    negative (-1) feedback. Otherwise, any record previously written for this
    id is deleted, e.g. when the user withdraws their feedback.

    ``query`` comes from the ``question`` store, which holds the question that
    was actually answered. The live text field may have been edited since.
    NOTE(review): ``domain`` is still read from the live text field, so it can
    differ from the domain used to produce the answer.

    :returns: an empty string for the hidden ``debug`` div.
    """
    if not answer:
        return ""
    path = f"/{question_id}.json"
    if LOG_ALL_ANSWERS or feedback in [1, -1]:
        record = {
            "question": query,
            "domain": domain,
            "answer": answer,
            # 2 means "no feedback yet" and is logged as neutral
            "feedback": 0 if feedback == 2 else feedback,
            "timestamp": str(datetime.now()),
            "webapp": WEBAPP_NAME,
            "version": VERSION,
            "search engine": SEARCH_ENGINE
        }
        with answers_folder.get_writer(path) as w:
            w.write(bytes(json.dumps(record), "utf-8"))
    elif path in answers_folder.list_paths_in_partition():
        answers_folder.delete_path(path)
    return ""

@app.callback(
    Output('link_ok', 'children'),
    Output('link_nok', 'children'),
    Input('feedback', 'data')
)
def update_icons(value):
    """
    Update the feedback icons when the user likes or dislikes an answer.

    ``value`` is the ``feedback`` store: 1 fills the smiley, -1 fills the
    frown, and any other value (including None) shows both icons outlined.
    """
    smiley = ok_icon_fill if value == 1 else ok_icon
    frown = nok_icon_fill if value == -1 else nok_icon
    return smiley, frown

@app.callback(
    Output('feedback', 'data'),
    Input('link_ok', 'n_clicks'),
    Input('link_nok', 'n_clicks'),
    Input('send-btn', 'n_clicks'),
    Input('query', 'n_submit'),
    Input('domain', 'n_submit'),
    State('feedback', 'data'),
    State('query', 'value'),
    State('question', 'data'),
)
def provide_feedback(ts_ok, ts_nok, click, submit, submit2, value, question, previous_question):
    """
    Record the feedback of the user.

    Possible values:
    - 1: liked
    - -1: disliked
    - 0: withdrawn
    - 2: no feedback yet for a new question

    Clicking an icon toggles its value on or off. Submitting a question that
    differs from the last answered one resets the value to 2; otherwise the
    current value is kept.
    """
    clicked_value = {"link_ok": 1, "link_nok": -1}.get(dash.ctx.triggered_id)
    if clicked_value is not None:
        return 0 if value == clicked_value else clicked_value
    if question != previous_question:
        return 2
    return value