# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import os
import re
import random
import dataiku

from functools import partial
from langchain.chat_models import ChatOpenAI
from langchain.tools import StructuredTool
from langchain.agents import initialize_agent, AgentType

from dataiku2tools import (
    create_tool_from_datasets,
    create_tool_from_python_endpoint,
    create_tool_from_model_endpoint,
    create_tool_from_managed_folder
)

from fictitious_tools import (
    get_customer_id,
    get_details,
    reset_password,
    cancel_appointment,
    schedule_local_intervention,
    schedule_distant_intervention,
    sign_up_to_option,
    cancel_option,
    cancel_phone_subscription,
    subscribe_to_loyalty_program,
    run_diagnostics,
    get_product_info
)

# Read the user secrets once: "openai_key" is exported for the OpenAI client,
# "api_key" is kept to authenticate against the API endpoints used by the tools.
auth_info = dataiku.api_client().get_auth_info(with_secrets=True)
for entry in auth_info["secrets"]:
    name = entry["key"]
    if name == "openai_key":
        os.environ["OPENAI_API_KEY"] = entry["value"]
    if name == "api_key":
        API_KEY = entry["value"]

# One deterministic (temperature 0) chat model per model name to compare
LLMS = {
    name: ChatOpenAI(model_name=name, temperature=0)
    for name in ("gpt-3.5-turbo", "gpt-4")
}

# Tool descriptions derived from the docstrings of the local functions
# (4-space indentation removed, surrounding whitespace stripped)
MODEL_ENDPOINT_DESCRIPTION = run_diagnostics.__doc__.replace("    ", "").strip()

# Same description without its last line, which is replaced by an instruction
# to provide an empty input (used when the customer id is already known)
MODEL_ENDPOINT_DESCRIPTION_WITH_ID = (
    MODEL_ENDPOINT_DESCRIPTION.rpartition("\n")[0]
    + "\nThe input should be the empty string"
)

PYTHON_ENDPOINT_DESCRIPTION = get_product_info.__doc__.replace("    ", "").strip()

# Id of the managed folder passed to create_tool_from_managed_folder
CHUNKS_FOLDER_ID = "DQwRivV3"

# Customer requests to process
df = dataiku.Dataset("customer_requests").get_dataframe()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Functions exposed as tools to the agent.
# NOTE: get_tools pairs this list index by index with the list returned by
# get_partial_functions, so both lists must keep the same order and length.
functions = [
    get_details,
    reset_password,
    cancel_appointment,
    schedule_local_intervention,
    schedule_distant_intervention,
    sign_up_to_option,
    cancel_option,
    cancel_phone_subscription,
    subscribe_to_loyalty_program
]

def get_partial_functions(customer_id):
    """
    Build the customer-bound variants of the tool functions.

    Each returned callable takes a single input `s`. For most functions this
    input is ignored and the function is called with `customer_id` only; for
    the option-related functions it is appended to the customer id as
    "<customer_id>,<s>".
    The callables are returned in the same order as the `functions` list.
    """
    def bind_id(tool_function):
        # The input received from the agent is deliberately ignored
        return lambda s: tool_function(customer_id)

    def bind_id_and_option(tool_function):
        # The input received from the agent is the option to act on
        return lambda s: tool_function(f"{customer_id},{s}")

    return [
        bind_id(get_details),
        bind_id(reset_password),
        bind_id(cancel_appointment),
        bind_id(schedule_local_intervention),
        bind_id(schedule_distant_intervention),
        bind_id_and_option(sign_up_to_option),
        bind_id_and_option(cancel_option),
        bind_id(cancel_phone_subscription),
        bind_id(subscribe_to_loyalty_program),
    ]

def get_tools(customer_id=None):
    """
    Provide either the general tools or the customer-specific tools.

    Args:
        customer_id: Identifier of the customer the request comes from, or None.
            When None, the tools are built directly from the functions listed in
            `functions` and no dataset restriction is applied. Otherwise, the
            customer-related tools are bound to `str(customer_id)` and the
            datasets handed to the dataset tool are restricted to this customer.

    Returns:
        The list of tools to provide to the agent.
    """
    tools = []
    # All the bound tools receive the customer id as a string
    bound_id = None if customer_id is None else str(customer_id)

    if customer_id is None:
        tools.extend(StructuredTool.from_function(f) for f in functions)
    else:
        partial_functions = get_partial_functions(bound_id)
        for function, partial_function in zip(functions, partial_functions):
            # The last line of the docstring is dropped and replaced below
            description = "\n".join(function.__doc__.strip().split("\n    ")[:-1])
            # If partial functions are used, their description should be amended
            if function.__name__ in ["sign_up_to_option", "cancel_option"]:
                description += "\nThe action input should be either 'TV' or 'Premium'"
            else:
                description += "\nThe action input should be the empty string"
            tools.append(StructuredTool.from_function(
                partial_function,
                name=function.__name__,
                description=description
            ))

    def post_process(x):
        # Turn the raw output of the diagnostic model endpoint into a sentence
        prediction = str(x)
        if prediction == "0":
            return f"The customer can directly solve the connection problem by following the instructions provided in the FAQ: www.telco-operator/faq/{random.randint(1, 100)}. An intervention of a technician is not warranted."
        if prediction == "1":
            return "A technician must visit the customer to solve the connection problem. An appointment needs to be made."
        if prediction == "2":
            return "A technician must discuss with the customer over the phone to solve the connection problem. An appointment needs to be made."
        return "Unexpected output"

    # The project variables are fetched only once for both endpoint URLs
    project = dataiku.api_client().get_default_project()
    variables = project.get_variables()["standard"]

    # Product information: API endpoint if a URL is configured, local function otherwise
    url1 = variables["get_product_info_url"]
    if url1:
        tools.append(
            create_tool_from_python_endpoint(
                url1,
                PYTHON_ENDPOINT_DESCRIPTION,
                api_key=API_KEY
            )
        )
    else:
        tools.append(StructuredTool.from_function(get_product_info))

    # Diagnostics: API endpoint if a URL is configured, local function otherwise
    url2 = variables["run_diagnostic_url"]
    if url2:
        if customer_id is None:
            tools.append(
                create_tool_from_model_endpoint(
                    url2,
                    MODEL_ENDPOINT_DESCRIPTION,
                    api_key=API_KEY,
                    post_process=post_process,
                    key="client_id"
                )
            )
        else:
            tools.append(
                create_tool_from_model_endpoint(
                    url2,
                    MODEL_ENDPOINT_DESCRIPTION_WITH_ID,
                    api_key=API_KEY,
                    post_process=post_process,
                    key="client_id",
                    value=bound_id
                )
            )
    else:
        if customer_id is None:
            tools.append(StructuredTool.from_function(run_diagnostics))
        else:
            # The customer id is passed as a string, like for the other bound tools
            tools.append(StructuredTool.from_function(
                lambda s: run_diagnostics(bound_id),
                name="run_diagnostics",
                description=run_diagnostics.__doc__ + "\nThe action input should be the empty string"
            ))

    # The datasets are restricted to the identified customer only when there is one.
    # (The condition used to be inverted: restrictions were built with the value
    # "None" for anonymous requests and dropped for identified customers.)
    if customer_id is not None:
        datasets_restrictions = [
            {"dataset_name": "customers_info_sql", "key": "id", "value": bound_id},
            {"dataset_name": "customers_invoices_sql", "key": "client_id", "value": bound_id}
        ]
    else:
        datasets_restrictions = []
    tools.append(create_tool_from_datasets("used_for_text2sql", datasets_restrictions=datasets_restrictions, llm=LLMS["gpt-3.5-turbo"]))

    additional_instructions = """
    Your answer should describe the applicable procedure in the manual for the specific need expressed by the customer.
    Focus on answering the specific question asked by the customer.
    """
    tools.append(create_tool_from_managed_folder(CHUNKS_FOLDER_ID, additional_instructions=additional_instructions, filters={"docs": "manual"}))

    return tools

def process_request(request, customer_id=None, datasets_restrictions=[], llm=ChatOpenAI(temperature=0), additional_instructions="", separator_for_additional_instructions="Begin!"):
    """
    Handle the customer's request.
    If the customer id is specified, only the corresponding customer can be affected by the agent.

    Args:
        request: Text of the customer's request.
        customer_id: Identifier of the customer, or None if the customer is not identified.
        datasets_restrictions: Currently unused by this function (never modified either).
        llm: Chat model driving the agent. Note that the default instance is
            created once, when the function is defined.
        additional_instructions: Text inserted in the agent's prompt template, right
            before the first occurrence of `separator_for_additional_instructions`
            (appended at the end of the template if the separator is absent).
        separator_for_additional_instructions: Marker used to locate the insertion point.

    Returns:
        A tuple (answer, actions, actions_details) of strings: the agent's final
        output, the numbered list of tool calls and the numbered list of the
        corresponding thoughts and observations.
    """

    tools = get_tools(customer_id=customer_id)

    agent = initialize_agent(
        tools,
        llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        handle_parsing_errors=True,
        return_intermediate_steps=True,
    )

    if additional_instructions:
        prompt = agent.agent.llm_chain.prompt
        splitted = prompt.template.split(separator_for_additional_instructions)
        splitted[0] += additional_instructions
        prompt.template = separator_for_additional_instructions.join(splitted)

    if customer_id is not None:
        try:
            # E.g. 12.0 or "12" are displayed and looked up as 12
            customer_id = int(customer_id)
        except (TypeError, ValueError, OverflowError):
            # Ids that cannot be converted to an integer are used as provided
            pass
        identity = get_details(str(customer_id)).split(",")[0] + f" (id: {customer_id})"
        request = f"Message received from {identity}: '{request}'"
    else:
        request = f"Request: '{request}'"

    result = agent(request)

    # Lists the actions and their details in two strings returned along with the answer
    steps = result["intermediate_steps"]
    actions = "\n".join(
        f"{number}. {action.tool}({action.tool_input})"
        for number, (action, _) in enumerate(steps, start=1)
    )
    # Each entry ends with a blank line; only the very last newline is removed
    actions_details = "".join(
        f"{number}. {action.log}\nObservation: {observation}\n\n"
        for number, (action, observation) in enumerate(steps, start=1)
    )[:-1]

    return result["output"], actions, actions_details

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Process every request with each model and store the results in the dataframe
for position in range(len(df)):

    row = df.iloc[position]
    requester_id = row.customer_id

    # (reply, actions, actions_details) obtained with each model
    outcomes = {}
    for model_name, chat_model in LLMS.items():
        try:
            reply, actions, actions_details = process_request(
                row.request,
                llm=chat_model,
                customer_id=requester_id
            )
        except Exception as e:
            # Best effort: the error message is stored instead of the results
            reply, actions, actions_details = str(e), str(e), str(e)
        outcomes[model_name] = (reply, actions, actions_details)

    gpt35_reply, gpt35_actions, gpt35_details = outcomes["gpt-3.5-turbo"]
    gpt4_reply, gpt4_actions, gpt4_details = outcomes["gpt-4"]

    df.at[position, "draft_reply"] = gpt35_reply
    df.at[position, "actions"] = gpt35_actions
    df.at[position, "actions_details"] = gpt35_details
    df.at[position, "draft_reply_gpt4"] = gpt4_reply
    df.at[position, "actions_gpt4"] = gpt4_actions
    df.at[position, "actions_details_gpt4"] = gpt4_details

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Write the requests with the results added above to the "answers" dataset
dataiku.Dataset("answers").write_with_schema(df)