# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
from dataikuapi.dss.modelevaluationstore import DSSModelEvaluationStore
import json
from datetime import datetime
import jsonschema
import numpy as np

# Identifier of the model evaluation store in which the scores are recorded.
EVALUATION_STORE_ID = "mYObd5To"

# Load the prediction datasets, one dataframe per evaluated approach.
df1, df2, df3 = (
    dataiku.Dataset(dataset_name).get_dataframe()
    for dataset_name in (
        "sft_predictions_gpt_4o",
        "sft_predictions_finetuned",
        "sft_predictions_finetuned_gpt-4o-mini",
    )
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
def _list_jaccard(items1, items2):
    """
    Return the Jaccard index of two lists viewed as sets (1 if both are empty).
    """
    try:
        set1, set2 = set(items1), set(items2)
    except TypeError:
        # Nested lists/objects are unhashable: compare the elements through
        # their canonical JSON encoding instead.
        set1 = {json.dumps(item, sort_keys=True) for item in items1}
        set2 = {json.dumps(item, sort_keys=True) for item in items2}
    union = len(set1 | set2)
    if union == 0:
        return 1
    return len(set1 & set2) / union


def jaccard_similarity(s1, s2):
    """
    Score the similarity between two JSON objects (0 = no key-value pairs in common, 1 = identical).

    Both arguments are expected to be JSON strings encoding an object. The score
    is averaged over the union of the top-level keys of the two objects:

    - a key present in only one object contributes 0;
    - if the value in the first object is a list, the key contributes the
      Jaccard index of the two lists when the other value is also a list
      (1 if both lists are empty) and 0 otherwise;
    - any other key contributes 1 if both values are equal and 0 otherwise.

    Returns 0 if either argument is not a string, is not valid JSON or does not
    decode to an object, and 1 if both arguments decode to empty objects.
    """
    try:
        d1 = json.loads(s1)
        d2 = json.loads(s2)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers non-string cells, e.g. None or the float NaN that
        # pandas uses for missing values.
        return 0
    if not isinstance(d1, dict) or not isinstance(d2, dict):
        return 0
    keys = d1.keys() | d2.keys()
    if not keys:
        # Two empty objects are identical; this also avoids dividing by zero.
        return 1
    score = 0
    for k, v1 in d1.items():
        if k not in d2:
            continue
        v2 = d2[k]
        if isinstance(v1, list):
            if isinstance(v2, list):
                score += _list_jaccard(v1, v2)
        elif v1 == v2:
            score += 1
    return score / len(keys)

def _string_field(description, choices=None):
    """
    Build a JSON-schema string property, restricted to `choices` when given.
    """
    field = {"type": "string"}
    if choices is not None:
        field["enum"] = list(choices)
    field["description"] = description
    return field


def _string_list_field(description, choices):
    """
    Build a JSON-schema array property whose items are strings drawn from `choices`.
    """
    return {"type": "array", "items": _string_field(description, choices)}


_YES_NO = ("yes", "no")

# JSON schema that the generated answers are validated against.
# Only "text_type" is mandatory; every other property is optional.
SCHEMA = {
    "type": "object",
    "properties": {
        "text_type": _string_field(
            "The type of text, indicating the purpose or intent of the message.",
            [
                "verify_attribute",
                "suggest",
                "request_attribute",
                "recommend",
                "request_explanation",
                "inform",
                "confirm",
                "request",
                "give_opinion",
            ],
        ),
        "name": _string_field("The name of the video game."),
        "release_year": {
            "type": "integer",
            "description": "The year when the video game was released.",
        },
        "esrb": _string_field(
            "The ESRB (Entertainment Software Rating Board) rating of the video game.",
            [
                "E (for Everyone)",
                "T (for Teen)",
                "E 10+ (for Everyone 10 and Older)",
                "M (for Mature)",
            ],
        ),
        "genres": _string_list_field(
            "The genres associated with the video game.",
            [
                "driving/racing",
                "pinball",
                "shooter",
                "turn-based strategy",
                "adventure",
                "fighting",
                "tactical",
                "arcade",
                "real-time strategy",
                "hack-and-slash",
                "text adventure",
                "strategy",
                "sport",
                "MMORPG",
                "trivia/board game",
                "simulation",
                "role-playing",
                "indie",
                "point-and-click",
                "action-adventure",
                "puzzle",
                "music",
                "vehicular combat",
                "platformer",
                "action",
            ],
        ),
        "platforms": _string_list_field(
            "The gaming platforms on which the video game is available.",
            ["PlayStation", "Xbox", "PC", "Nintendo", "Nintendo Switch"],
        ),
        "available_on_steam": _string_field(
            "Indicates whether the game is available on the Steam platform.",
            _YES_NO,
        ),
        "has_linux_release": _string_field(
            "Indicates whether the game has a Linux release.",
            _YES_NO,
        ),
        "has_mac_release": _string_field(
            "Indicates whether the game has a Mac release.",
            _YES_NO,
        ),
        "specifier": _string_field(
            "A specifier providing additional details or context."
        ),
        "rating": _string_field(
            "The overall rating or opinion of the video game.",
            ["excellent", "good", "average", "poor"],
        ),
        "player_perspective": _string_list_field(
            "The player's perspective or viewpoint while playing the game.",
            ["first person", "third person", "side view", "bird view"],
        ),
        "has_multiplayer": _string_field(
            "Indicates whether the game has a multiplayer mode.",
            _YES_NO,
        ),
        "exp_release_date": _string_field(
            "The expected release date for upcoming games."
        ),
        "developer": _string_field(
            "The developer or development studio responsible for creating the video game."
        ),
    },
    "required": ["text_type"],
}

def check_schema(s):
    """
    Check compliance with the JSON schema (0 = not compliant, 1 = compliant).

    `s` is expected to be a JSON string. Returns a `(compliance, error)` tuple:
    `(1, "")` when `s` decodes to a document that validates against `SCHEMA`,
    otherwise `(0, message)` where `message` describes the failure.
    """
    try:
        instance = json.loads(s)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers non-string cells, e.g. None or the float NaN that
        # pandas uses for missing values; they are scored as non-compliant
        # instead of aborting the whole evaluation.
        return 0, "Invalid JSON format"
    try:
        jsonschema.validate(instance=instance, schema=SCHEMA)
    except jsonschema.ValidationError as e:
        return 0, str(e)
    return 1, ""

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Compute the metrics for each approach

def _add_metric_columns(df):
    """
    Add the per-row evaluation columns to a predictions dataframe, in place.

    Reads the "llm_output" column (generated answer) and the "output" column
    (answer it is compared against) and writes the "schema_compliance",
    "schema_error" and "ground_truth_similarity" columns.
    """
    checks = [check_schema(s) for s in df["llm_output"]]
    # Split the (compliance, error) pairs explicitly so that an empty dataframe
    # yields empty columns rather than a failed tuple unpacking.
    df["schema_compliance"] = [compliance for compliance, _ in checks]
    df["schema_error"] = [error for _, error in checks]
    # Pair the two columns positionally rather than looking each row up by its
    # index label: faster, and still correct if index labels are duplicated.
    df["ground_truth_similarity"] = [
        jaccard_similarity(expected, generated)
        for expected, generated in zip(df["output"], df["llm_output"])
    ]


_add_metric_columns(df1)
_add_metric_columns(df2)
_add_metric_columns(df3)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Connect to the DSS API and open the model evaluation store that receives the scores.
client = dataiku.api_client()
project = client.get_default_project()
# `mes` is read as a module-level global by add_scores().
mes = project.get_model_evaluation_store(EVALUATION_STORE_ID)

def add_scores(df, s):
    """
    Push the aggregated metrics of one approach to the evaluation store.

    The means of the "schema_compliance" and "ground_truth_similarity" columns
    of `df` are recorded in the module-level store `mes` as a custom model
    evaluation named `s`, labelled with the current local timestamp.
    """
    compliance_metric = DSSModelEvaluationStore.MetricDefinition(
        code="Schema compliance",
        value=df["schema_compliance"].mean(),
        name="Average compliance score (1 = full compliance)",
        description="Measures the ratio of generated answers compliant with the target JSON schema",
    )
    similarity_metric = DSSModelEvaluationStore.MetricDefinition(
        code="Similarity with the ground-truth answer",
        value=df["ground_truth_similarity"].mean(),
        name="Average similarity score (1 = identical to the ground-truth answer)",
        description="Measures how close the generated answers are to the ground-truth answers",
    )

    # Tag the evaluation with the moment it was recorded.
    timestamp_label = DSSModelEvaluationStore.LabelDefinition(
        "evaluation:date", datetime.now().isoformat()
    )
    mes.add_custom_model_evaluation(
        [compliance_metric, similarity_metric],
        name=s,
        labels=[timestamp_label],
    )

# Record one evaluation per approach; the second argument is the name under
# which the evaluation is stored.
add_scores(df1, "gpt-4o with few-shot examples")
add_scores(df2, "Fine-tuned mistral-7b-instruct-v0.2")
add_scores(df3, "Fine-tuned gpt-4o-mini")