# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
import dataiku
import json
import re

# Load the agent runs to evaluate; each evaluation method below adds its own columns to `df`.
df = (
    dataiku.Dataset("requests_processed_langgraph")
    .get_dataframe()
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # Agent Evaluation Method #1: LLM-as-a-judge with `langchain`
# Required inputs (dataset columns):
# - request (`request`)
# - reply (`draft_reply`)
# - intermediate steps (`intermediate_steps`)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
from dataiku.langchain.dku_llm import DKUChatLLM
from langchain.evaluation import load_evaluator
from langchain_core.agents import AgentAction

# LLM used as the judge. Its id is read from the project's custom variables;
# temperature 0 reduces run-to-run variance in the judge's verdicts.
LLM_ID = dataiku.get_custom_variables()["LLM_id"]
llm = DKUChatLLM(llm_id=LLM_ID, temperature=0)

# LangChain "trajectory" evaluator: grades the final reply together with the
# intermediate steps that produced it (see `evaluate_agent_trajectory` below).
evaluator = load_evaluator("trajectory", llm=llm)

# Parses one serialized intermediate step of the form
#   `<N>. <tool>(<JSON-encoded tool input>) --> <tool output>`
# group 1 = tool name, group 2 = tool input (JSON text), group 3 = tool output.
# - Raw string: in a normal string literal `\(` is an invalid escape sequence
#   (SyntaxWarning since Python 3.12).
# - The tool name is matched lazily (up to the FIRST "("), so parentheses inside
#   the tool input, e.g. a query "a (b)", are not absorbed into the tool name.
# NOTE(review): `[0-9]*.` accepts any single character after optional digits,
# not only a literal "." — kept as-is to avoid rejecting existing data.
pattern = re.compile(r"[0-9]*. (.*?)\((.*)\) --> (.*)")

def deserialize_intermediate_steps(trajectory):
    """
    Deserialize the string representing the intermediate steps of an agent trajectory.

    Each non-blank line of `trajectory` must match the module-level `pattern`
    (``N. tool(<JSON input>) --> <output>``) and becomes one
    ``(AgentAction, output)`` pair, the shape passed as `agent_trajectory` to
    LangChain's trajectory evaluator.

    Args:
        trajectory: newline-separated serialized steps. Blank or whitespace-only
            lines (e.g. a trailing newline) are skipped, so an empty string
            yields an empty list.

    Returns:
        list of ``(AgentAction, str)`` tuples, in the original step order.

    Raises:
        ValueError: if a non-blank line does not match `pattern`
            (`json.JSONDecodeError`, a `ValueError` subclass, is raised when a
            tool input is not valid JSON).
    """
    result = []
    for line_number, line in enumerate(trajectory.split("\n"), start=1):
        if not line.strip():
            # Without this, "" or a trailing newline makes `pattern.match`
            # return None and the whole row fails.
            continue
        m = pattern.match(line)
        if m is None:
            raise ValueError(
                f"Cannot parse intermediate step on line {line_number}: {line!r}"
            )
        tool, tool_input, output = m.group(1), m.group(2), m.group(3)
        result.append(
            (
                AgentAction(tool=tool, tool_input=json.loads(tool_input), log=""),
                output,
            )
        )
    return result

# Method #1 main loop: let the LLM judge grade every row's trajectory and store
# its score and reasoning in two new columns.
for i, request, reply, serialized_steps in zip(
    df.index, df["request"], df["draft_reply"], df["intermediate_steps"]
):
    result = evaluator.evaluate_agent_trajectory(
        prediction=reply,
        input=request,
        agent_trajectory=deserialize_intermediate_steps(serialized_steps),
    )
    df.at[i, "langchain/score"] = result["score"]
    df.at[i, "langchain/justification"] = result["reasoning"]

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # Agent Evaluation Method #2: conformance checking with `pm4py`
# Required inputs (dataset columns):
# - process model of the ground-truth trajectory, written in YAML (`reference_trajectory`)
# - intermediate steps (`intermediate_steps`)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
import yaml
import pandas as pd

import pm4py
from pm4py.objects.process_tree.obj import Operator, ProcessTree

# Operator names accepted as mapping keys in the YAML reference trajectories
# (see `create_tree`), mapped to the matching pm4py process-tree operators.
str2operator = {
    operator_name: getattr(Operator, operator_name)
    for operator_name in (
        "SEQUENCE",
        "XOR",
        "PARALLEL",
        "LOOP",
        "OR",
        "INTERLEAVING",
        "PARTIALORDER",
    )
}

def create_tree(d):
    """
    Convert a dictionary into a `pm4py` Process Tree.

    `d` is the result of `yaml.safe_load` on a reference trajectory and is
    interpreted recursively:
      - a string is a leaf, i.e. an action label;
      - a list is a SEQUENCE of its elements;
      - a single-key dict ``{OPERATOR: [children...]}`` applies OPERATOR
        (a key of `str2operator`) to its children.
    Cf. the examples in the `requests` dataset.

    Each child's `parent` is set to the node that contains it, as in pm4py's
    process-tree construction examples.

    Raises:
        ValueError: if a dict node does not have exactly one key.
        KeyError: if an operator name is not a key of `str2operator`.
        TypeError: if a node is neither a string, a list nor a dict.
    """
    if isinstance(d, str):
        # The actions in the input YAML use `=` instead of `: ` (a plain YAML
        # scalar cannot contain `: `); restore `: ` so the label matches the
        # JSON-serialized tool inputs of the observed actions.
        # Note that EVERY `=` in the label is replaced.
        # Cf. the examples in the `requests` dataset.
        return ProcessTree(label=d.replace("=", ": "))

    if isinstance(d, list):
        node = ProcessTree(
            operator=Operator.SEQUENCE,
            children=[create_tree(x) for x in d],
        )
    elif isinstance(d, dict):
        # Previously only the first key was used and any others were silently
        # dropped (and an empty dict returned None): reject such nodes explicitly.
        if len(d) != 1:
            raise ValueError(
                f"A process-tree node must have exactly one operator key, got {list(d)!r}"
            )
        operator_name, children = next(iter(d.items()))
        node = ProcessTree(
            operator=str2operator[operator_name],
            children=[create_tree(x) for x in children],
        )
    else:
        raise TypeError(
            f"Unsupported process-tree node of type {type(d).__name__}: {d!r}"
        )

    # ProcessTree(children=...) does not link the children back to their parent.
    for child in node.children:
        child.parent = node
    return node

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
def _activity_label(action):
    """
    Render an `AgentAction` as the activity label used for conformance checking:
    ``tool(input)``, where dict inputs are serialized with `json.dumps` (default
    separators ``", "`` and ``": "``) and other inputs are formatted as-is.
    """
    tool_input = action.tool_input
    if isinstance(tool_input, dict):
        tool_input = json.dumps(tool_input)
    return f"{action.tool}({tool_input})"


def _format_alignment(alignment):
    """
    Render a pm4py alignment as one ``(log move, model move)`` tuple per line,
    dropping the ``('>>', None)`` and ``(None, '>>')`` entries (moves that involve
    a silent, unlabelled step).
    """
    silent_moves = [(">>", None), (None, ">>")]
    return "\n".join(f"{x[0], x[1]}" for x in alignment if x not in silent_moves)


# Method #2 main loop: for every row, align the observed tool calls against the
# reference process model. Writes `trajectory_fit` (alignment fitness) and
# `trajectory_alignment` (as rendered by `_format_alignment`).
for i in df.index:
    intermediate_steps = deserialize_intermediate_steps(df.at[i, "intermediate_steps"])
    trajectory = [_activity_label(action) for action, _observation in intermediate_steps]

    tree = create_tree(yaml.safe_load(df.at[i, "reference_trajectory"]))

    if not trajectory:
        # An empty event DataFrame yields a log without any case, so there is no
        # alignment to read (`aligned_traces[0]` would raise IndexError).
        # NOTE(review): recorded as "not evaluated"; decide whether an agent that
        # took no action should instead be scored against the model.
        df.at[i, "trajectory_fit"] = float("nan")
        df.at[i, "trajectory_alignment"] = ""
        continue

    # Single case (id 0). Each event gets its own, strictly increasing timestamp
    # so the event order in the log is the step order, even if pm4py re-sorts
    # events by timestamp (identical timestamps would leave that to tie-breaking).
    start = pd.Timestamp(2024, 1, 1, 12)
    log_df = pd.DataFrame.from_dict(
        {
            "case_id": [0] * len(trajectory),
            "activity": trajectory,
            "ts": [start + pd.Timedelta(seconds=k) for k in range(len(trajectory))],
        }
    )
    log = pm4py.convert_to_event_log(
        pm4py.format_dataframe(
            log_df,
            case_id="case_id",
            activity_key="activity",
            timestamp_key="ts",
        )
    )
    net, initial_marking, final_marking = pm4py.convert_to_petri_net(tree)
    aligned_traces = pm4py.conformance_diagnostics_alignments(
        log, net, initial_marking, final_marking
    )

    # The log holds exactly one case, hence exactly one alignment result.
    df.at[i, "trajectory_fit"] = aligned_traces[0]["fitness"]
    df.at[i, "trajectory_alignment"] = _format_alignment(aligned_traces[0]["alignment"])

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # Agent Evaluation Method #3: LLM-as-a-judge with `mlflow`
# Required inputs (dataset columns):
# - request (`request`)
# - reply (`draft_reply`)
# - ground-truth reply (`reference_answer`)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
import mlflow
from mlflow.metrics import genai
import os

# Export the user's Dataiku secret "openai_key" (if there is one) as
# OPENAI_API_KEY, presumably for the OpenAI judge model used by mlflow's GenAI
# metric below. Nothing is changed when no such secret exists.
auth_info = dataiku.api_client().get_auth_info(with_secrets=True)
openai_secret = next(
    (entry for entry in auth_info["secrets"] if entry["key"] == "openai_key"),
    None,
)
if openai_secret is not None:
    os.environ["OPENAI_API_KEY"] = openai_secret["value"]

with mlflow.start_run() as run:
    # Judge each `draft_reply` against its `reference_answer`; the `request`
    # column is mapped to the judge's `inputs`.
    results = mlflow.evaluate(
        data=df,
        targets="reference_answer",
        predictions="draft_reply",
        extra_metrics=[genai.answer_correctness()],
        evaluators="default",
        evaluator_config={"col_mapping": {"inputs": "request"}},
    )

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Keep the input columns plus the outputs of the three evaluation methods, give
# the mlflow metric columns the same "<method>/<field>" naming as the langchain
# ones, and publish the result.
df = (
    results.tables["eval_results_table"]
    .loc[
        :,
        [
            "customer_id",
            "request",
            "reference_answer",
            "reference_trajectory",
            "draft_reply",
            "actions",
            "intermediate_steps",
            "langchain/score",
            "langchain/justification",
            "trajectory_fit",
            "trajectory_alignment",
            "answer_correctness/v1/score",
            "answer_correctness/v1/justification",
        ],
    ]
    .rename(
        columns={
            "answer_correctness/v1/justification": "mlflow/justification",
            "answer_correctness/v1/score": "mlflow/score",
        }
    )
)
dataiku.Dataset("agent_answers_evaluated").write_with_schema(df)