# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import os
import re
import dataiku
import json
import logging
import functools
import time
from datetime import datetime
from openai import OpenAI
import tiktoken

from project_utils import YouSearchWrapper, BraveSearchWrapper, filter_urls, with_timeout, index_urls, clean_html

df = dataiku.Dataset("questions").get_dataframe()
mlflow_folder = dataiku.Folder("gnnCu5cW")
dss_client = dataiku.api_client()
project = dss_client.get_default_project()

# Copy the user secrets needed by the search wrappers into the environment and build
# the OpenAI client. Fail fast when the OpenAI key is missing: otherwise the first
# question would only crash later with a NameError on `openai_client`.
openai_client = None
for secret in dss_client.get_auth_info(with_secrets=True)["secrets"]:
    if secret["key"] in ("BRAVE_API_KEY", "YDC_API_KEY"):
        os.environ[secret["key"]] = secret["value"]
    elif secret["key"] == "openai_key":
        openai_client = OpenAI(api_key=secret["value"])
if openai_client is None:
    raise RuntimeError("The 'openai_key' user secret is not defined; it is required to call the OpenAI API.")

SEARCH_ENGINE = "You"  # "You" or "Brave"
LLM = "gpt-3.5-turbo"
NUM_SEARCH_RESULTS = 10  # passed to the search wrapper (presumably the number of results per query)
NUM_CHUNKS = 10  # number of page chunks retrieved from the index per query
MAX_TOKENS_CONTEXT = 3500  # token budget for the facts sent to the LLM
encoding = tiktoken.encoding_for_model(LLM)

if SEARCH_ENGINE == "Brave":
    search_engine = BraveSearchWrapper(NUM_SEARCH_RESULTS)
elif SEARCH_ENGINE == "You":
    search_engine = YouSearchWrapper(NUM_SEARCH_RESULTS)
else:
    # Reject typos instead of silently falling back to another engine.
    raise ValueError(f"Unsupported SEARCH_ENGINE {SEARCH_ENGINE!r}: expected 'You' or 'Brave'.")

# Markdown characters protected by escape_markdown, in the order they are processed.
_MARKDOWN_SPECIAL_CHARS = "*`_~>[]()"


def escape_markdown(text):
    """
    Backslash-escape the Markdown characters * ` _ ~ > [ ] ( ) so sources render literally.

    A backslash already in front of one of these characters is removed first, so text
    that is already escaped does not end up double-escaped.
    """
    # Pass 1: drop existing escapes.
    for char in _MARKDOWN_SPECIAL_CHARS:
        text = text.replace("\\" + char, char)
    # Pass 2: escape every occurrence.
    for char in _MARKDOWN_SPECIAL_CHARS:
        text = text.replace(char, "\\" + char)
    return text

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
SYSTEM_PROMPT = """You are a helpful assistant that answers questions.
You can use an Internet search engine if needed.
When asked a question, you either directly reply if you know the answer or provide one or several Internet search queries that can help you get the right answer."""

# OpenAI "function calling" spec forcing the model to either answer directly or propose
# up to 3 web search queries. It does not depend on the question, so it is built once.
_ANSWER_OR_SEARCH_FUNCTION = [
    {
        "name": "answer_or_get_search_query",
        "description": "Display the answer or provide one or several search queries",
        "parameters": {
            "type": "object",
            "properties": {
                "answer": {
                    "type": "string",
                    "description": "Answer to the question when it is known. If unsure, this answer should be the empty string and a list of Internet search queries useful to find the answer should be provided",
                },
                "search_queries": {
                    "type": "array",
                    "minItems": 0,
                    "maxItems": 3,
                    "items": {
                        "type": "string"
                    },
                    "description": "List of Internet search queries useful to find the answer to the question, if it is unknown or uncertain. Return an empty list if the answer is already known",
                },
            },
            "required": ["answer", "search_queries"],
        },
    }
]


def _parse_answer_or_search_queries(arguments):
    """
    Parse the JSON arguments of the "answer_or_get_search_query" function call.

    Returns ``(answer, search_queries)``:
    - ``answer`` is a stripped string, "" when missing, null or not a string;
    - ``search_queries`` is a list of non-empty strings, [] when missing or invalid.

    Malformed JSON is logged and treated as "no answer, no queries" instead of raising.
    """
    try:
        parsed = json.loads(arguments)
    except (TypeError, json.JSONDecodeError):
        logging.warning("Could not parse the function call arguments: %r", arguments)
        parsed = {}
    if not isinstance(parsed, dict):
        parsed = {}

    answer = parsed.get("answer")
    answer = answer.strip() if isinstance(answer, str) else ""

    raw_queries = parsed.get("search_queries")
    if isinstance(raw_queries, str):
        # Tolerate a single query returned as a bare string instead of a list.
        raw_queries = [raw_queries]
    elif not isinstance(raw_queries, list):
        raw_queries = []
    search_queries = [q.strip() for q in raw_queries if isinstance(q, str) and q.strip()]
    return answer, search_queries


@functools.lru_cache()
def get_answer_or_search_queries(question):
    """
    Ask the LLM to either answer `question` directly or propose web search queries.

    Results are memoised per question with ``functools.lru_cache`` (default 128 entries).
    The prompt embeds the current date and time, so a cached result reflects the time of
    the first call for that question. The returned list is shared between cached calls
    and must not be mutated by callers.

    :param question: the user question.
    :return: ``(answer, search_queries)``. ``answer`` is "" when the model did not answer.
        ``search_queries`` is a possibly empty list of query strings (never None).
    """
    messages = [
        {
            "role": "system",
            "content": SYSTEM_PROMPT
        },
        {
            "role": "system",
            "content": f"The current date and time are: {str(datetime.now())}"
        },
        {
            "role": "user",
            "content": question
        },
    ]

    response = openai_client.chat.completions.create(
        model=LLM,
        functions=_ANSWER_OR_SEARCH_FUNCTION,
        function_call={"name": "answer_or_get_search_query"},
        messages=messages,
        temperature=0
    )
    return _parse_answer_or_search_queries(response.choices[0].message.function_call.arguments)


# System prompt for the "answer from retrieved facts" step. get_answer_with_sources
# sends it together with facts labelled REF1, REF2, ... and asks for an answer plus
# the ids of the supporting facts.
SYSTEM_PROMPT_RAG = """You are a helpful assistant that answers questions based on facts retrieved from the Internet.
You justify your answers by referring to the relevant facts.
Don't include the references to these facts (e.g. "REF1", "REF2"...) directly in your answer.
Include only one of these facts except if several facts are needed to reach your conclusion.
If you are unsure, answer: "I don't know"."""

def format_sources(extracts):
    """
    Render search results as Markdown lines of the form ``[title](link): snippet``.

    Each extract is a mapping with 'title', 'link' and 'snippet' keys. Every field is
    passed through escape_markdown before being inserted.
    """
    return [
        "[{}]({}): {}".format(
            escape_markdown(extract['title']),
            escape_markdown(extract['link']),
            escape_markdown(extract['snippet']),
        )
        for extract in extracts
    ]

def get_answer_function(source_ids, max_sources=2):
    """
    Generate the function specs for the OpenAI "function calling" feature.

    :param source_ids: allowed source identifiers (e.g. ["REF1", "REF2"]), used as the
        enum of the "sources" items. When empty, no enum is emitted, because an empty
        enum admits no value at all.
    :param max_sources: maximum number of sources the model may cite. It is used both for
        the schema's ``maxItems`` and in the description text so the two always agree.
    :return: a one-element list holding the "display_answer" function spec.
    """
    source_item_schema = {"type": "string"}
    if source_ids:
        source_item_schema["enum"] = list(source_ids)

    return [
        {
            "name": "display_answer",
            "description": "Display the answer and the relevant sources",
            "parameters": {
                "type": "object",
                "properties": {
                    "answer_found": {
                        "type": "boolean",
                        "description": "Answer found. Whether an answer has been found given the facts provided.",
                    },
                    "sources": {
                        "type": "array",
                        "minItems": 0,
                        "maxItems": max_sources,
                        "items": source_item_schema,
                        "description": f"Sources supporting the answer. Sources are denoted by REF1, REF2... Mention at most {max_sources} sources. Do not include redundant sources",
                    },
                    "answer": {
                        "type": "string",
                        "description": "Answer. Don't include the references to these facts (e.g. 'REF1', 'REF2'...) directly in your answer."
                    },
                },
                "required": ["answer", "sources", "answer_found"],
            },
        }
    ]

# Extracts the number from a source id such as "REF3".
_SOURCE_NUMBER_PATTERN = re.compile(r"\d+")


def _truncate_context(context):
    """Return the longest prefix of `context` whose cumulative token count fits in MAX_TOKENS_CONTEXT."""
    kept = []
    num_tokens = 0
    for fact in context:
        num_tokens += len(encoding.encode(fact))
        if num_tokens > MAX_TOKENS_CONTEXT:
            break
        kept.append(fact)
    return kept


def _parse_source_indices(source_ids, num_facts):
    """Map source ids like "REF2" to unique, in-range 0-based fact indices, keeping their order."""
    indices = []
    for source_id in source_ids:
        match = _SOURCE_NUMBER_PATTERN.search(str(source_id))
        if match is None:
            continue
        index = int(match.group()) - 1  # REF numbering is 1-based
        if 0 <= index < num_facts and index not in indices:
            indices.append(index)
    return indices


def get_answer_with_sources(question, context):
    """
    Answer the question using the context information and providing sources.

    :param question: the user question.
    :param context: list of fact strings (Markdown-formatted search results or page chunks).
        Only the longest prefix fitting in MAX_TOKENS_CONTEXT tokens is sent to the LLM.
    :return: ``(answer, answer_found)``. When the model cites valid sources, the answer
        ends with a "Sources:" section listing the cited facts, each at most once.
        When no fact fits in the budget (including an empty `context`), no LLM call is
        made and ``("I don't know", False)`` is returned.
    """
    context = _truncate_context(context)
    if not context:
        # There is nothing to ground an answer on. The original loop also crashed here
        # (UnboundLocalError) when `context` was empty.
        return "I don't know", False

    formatted_context = "\n\n" + "\n\n".join(
        f"REF{k}. {fact}" for k, fact in enumerate(context, start=1)
    ) + "\n\n"
    user_prompt = f"Based on the following facts:{formatted_context}...Answer this question: {question}"

    messages = [
        {
            "role": "system",
            "content": SYSTEM_PROMPT_RAG
        },
        {
            "role": "system",
            "content": f"The current date and time are: {str(datetime.now())}"
        },
        {
            "role": "user",
            "content": user_prompt
        },
    ]

    response = openai_client.chat.completions.create(
        model=LLM,
        functions=get_answer_function([f"REF{k}" for k in range(1, len(context) + 1)]),
        function_call={"name": "display_answer"},
        messages=messages,
        temperature=0
    )

    response = json.loads(response.choices[0].message.function_call.arguments)
    answer = response.get("answer") or ""

    source_indices = _parse_source_indices(response.get("sources") or [], len(context))
    if source_indices:
        # Only add the section when at least one cited ref points to a real fact.
        formatted_sources = "\n\n".join(context[index] for index in source_indices)
        answer += f"\n\nSources:\n\n{formatted_sources}"

    return answer, bool(response.get("answer_found", False))

def basic_web_research(question, queries=None, domain=None):
    """
    Answer a question with the results of a web search.

    :param question: the user question. It is also used as the search query when
        `queries` is empty.
    :param queries: optional list of search queries. Their results are concatenated.
    :param domain: optional domain; when set, " site:<domain>" is appended to every query.
    :return: ``((answer, answer_found), formatted_search_results, links)``. Results
        returned by several queries are kept only once (first occurrence, by link), so
        they do not waste the LLM token budget or get scraped twice later.
    """
    domain_suffix = f" site:{domain}" if domain else ""
    search_terms = list(queries) if queries else [question]

    search_results = []
    seen_links = set()
    for term in search_terms:
        for result in search_engine.results(term + domain_suffix) or []:
            link = result["link"]
            if link in seen_links:
                continue
            seen_links.add(link)
            search_results.append(result)

    formatted_search_results = format_sources(search_results)
    logging.info("Web search: %s", "\n".join(formatted_search_results))
    return get_answer_with_sources(
        question,
        formatted_search_results
    ), formatted_search_results, [result["link"] for result in search_results]

def format_chunks(chunks):
    """
    Render semantic-search chunks as Markdown lines of the form ``source: content``.

    Each chunk must expose ``metadata['source']`` and ``page_content``. Both are
    passed through escape_markdown.
    """
    def render(chunk):
        origin = escape_markdown(chunk.metadata['source'])
        body = escape_markdown(chunk.page_content)
        return f"{origin}: {body}"

    return [render(chunk) for chunk in chunks]

def advanced_web_research(question, urls, queries=None):
    """
    Answer a question with the content of some web pages.

    :param question: the user question. It is also used as the retrieval query when
        `queries` is empty.
    :param urls: URLs to filter, index and search.
    :param queries: optional list of retrieval queries. Their chunks are concatenated.
    :return: ``((answer, answer_found), formatted_chunks)``, the same shape on every path.
        When there is nothing to index, the result is ``(("Unknown", False), [])``.
        The previous failure value ``("Unknown", False)`` made the caller's
        ``(answer, found), context = ...`` unpacking fail.
    """
    if not urls:
        # Skip indexing altogether when there is nothing to index.
        return ("Unknown", False), []
    index = index_urls(filter_urls(urls))
    if index is None:
        return ("Unknown", False), []

    search_terms = list(queries) if queries else [question]
    chunks = []
    for term in search_terms:
        chunks += index.similarity_search(term, k=NUM_CHUNKS)

    # Several queries often retrieve the same chunk. Keep the first occurrence only,
    # so REF numbering and the token budget are not wasted on duplicates.
    formatted_chunks = list(dict.fromkeys(format_chunks(chunks)))
    return get_answer_with_sources(
        question,
        formatted_chunks
    ), formatted_chunks

def _log_final_metrics(mlflow_handle, start, steps_needed):
    """Record the elapsed time and the number of cascade steps used (no-op without an MLflow handle)."""
    if mlflow_handle is not None:
        mlflow_handle.log_metric("delay", time.time() - start)
        mlflow_handle.log_metric("steps_needed", steps_needed)


def get_answer(question, domain=None, mlflow_handle=None):
    """
    Answer the question directly, or with search results only, or with the content of the web pages returned as search results.

    The three steps are tried in order and the first one that finds an answer wins:
    1. LLM only (get_answer_or_search_queries);
    2. LLM + web search result snippets (basic_web_research);
    3. LLM + content of the pages returned by the search (advanced_web_research).

    :param question: the question to answer.
    :param domain: optional domain restricting the web search (appended as "site:<domain>").
    :param mlflow_handle: optional MLflow handle. When given, intermediate results are
        logged as JSON artifacts and the "delay" / "steps_needed" metrics are recorded.
    :return: ``(answer, answer_found)``.
    """
    start = time.time()

    # Step 1: LLM only.
    answer, search_queries = get_answer_or_search_queries(question)
    # Step 1 can report a missing answer or query list as None. Normalise both so the
    # truthiness test and the ", ".join below cannot fail.
    answer = answer or ""
    search_queries = list(search_queries or [])
    if mlflow_handle is not None:
        mlflow_handle.log_dict(
            {
                "question": question,
                "answer": answer,
                "search_queries": search_queries
            },
            artifact_file="LLM_only.json"
        )
    if answer:
        logging.info("Answer found without web search:\n" + answer)
        _log_final_metrics(mlflow_handle, start, 1)
        return answer, True

    # Step 2: search result snippets.
    logging.info(f"Web search needed. Search queries: {', '.join(search_queries)}")
    (answer, answer_found), context, urls = basic_web_research(question, domain=domain, queries=search_queries)
    if mlflow_handle is not None:
        mlflow_handle.log_dict(
            {
                "context": context,
                "answer": answer,
                "answer_found": answer_found,
                "urls": urls
            },
            artifact_file="search.json"
        )
    if answer_found:
        logging.info("Answer based on the search results snippets:\n" + answer)
        _log_final_metrics(mlflow_handle, start, 2)
        return answer, True
    logging.info("Answer not found only with search results:\n" + f"{answer} / {', '.join(urls)}")

    # Step 3: content of the web pages returned by the search.
    (answer, answer_found), context = advanced_web_research(question, urls, queries=search_queries)
    if mlflow_handle is not None:
        mlflow_handle.log_dict(
            {
                "context": context,
                "answer": answer,
                "answer_found": answer_found
            },
            artifact_file="search_and_web_scraping.json"
        )
    if answer_found:
        logging.info("Answer based on web pages:\n" + answer)
    else:
        logging.info(f"Answer not found: {answer}")

    _log_final_metrics(mlflow_handle, start, 3)
    return answer, answer_found

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Answer every question, one MLflow run per question, and store the answers in `df`.
with project.setup_mlflow(managed_folder=mlflow_folder) as mlflow_handle:
    mlflow_handle.set_experiment("cascade RAG")
    for i in df.index:
        question = df.loc[i, "question"]
        # The try wraps the whole run, so the run context is exited before the error is
        # handled. One failing question (API error, scraping timeout, ...) is logged
        # instead of aborting the batch, which would lose every answer computed so far:
        # the answers are only written to the output dataset in the next cell.
        try:
            with mlflow_handle.start_run(run_name=f"cascade_RAG_{i}"):
                answer, _ = get_answer(question, mlflow_handle=mlflow_handle)
        except Exception:
            logging.exception("Could not answer question %s: %r", i, question)
            answer = None
        df.at[i, "answer"] = answer

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Persist the questions together with the generated answers, replacing the output schema.
answers_dataset = dataiku.Dataset("answers_cascade_RAG")
answers_dataset.write_with_schema(df)