import json

from langgraph.graph import StateGraph, MessagesState, START, END
from langchain_core.messages import SystemMessage, HumanMessage
from pydantic import BaseModel, ValidationError
from typing_extensions import TypedDict
from typing import Optional, List

import dataiku
from dataiku.llm.python import BaseLLM
from dataiku.langchain import LangchainToDKUTracer


# LangChain chat model used by generate_json. Replace the placeholder
# connection name in the LLM id before running.
llm = (
    dataiku.api_client()
    .get_default_project()
    .get_llm("openai:YOUR_OPENAI_CONNECTION_NAME:gpt-5-mini")
    .as_langchain_chat_model()
)

# Number of generation attempts after which decide_to_finish_or_retry stops
# looping, even if validation is still failing.
MAX_ATTEMPTS = 3


# Strict Pydantic schema for the data we want to extract. validate_and_parse
# checks the LLM's raw output against this model, and SCHEMA_DEFINITION is
# derived from it.
# NOTE(review): documented with `#` comments rather than a class docstring on
# purpose - per Pydantic's JSON Schema behaviour a docstring would be emitted
# as the schema `description` and so change SCHEMA_DEFINITION; confirm before
# adding one.
class ProductReview(BaseModel):
    # Required string field.
    product_name: str
    # Required integer; the schema enforces no minimum or maximum.
    rating: int
    # Required string field.
    review_summary: str
    # Optional list of strings; defaults to None when absent.
    positive_points: Optional[List[str]] = None
    # Optional list of strings; defaults to None when absent.
    negative_points: Optional[List[str]] = None


# JSON Schema for ProductReview, computed once at import time.
# model_json_schema() returns a plain dict (not a JSON string); generate_json
# embeds it in the system prompt.
SCHEMA_DEFINITION = ProductReview.model_json_schema()


class ExtractorState(TypedDict):
    """State carried between the nodes of the extraction graph."""

    input_text: str  # The unstructured text to process
    # Optional because the initial state seeds this with None; it only holds
    # a string once generate_json has run.
    extracted_json: Optional[str]  # The LLM's raw JSON string output
    validation_error: Optional[str]  # Error message if parsing fails
    attempt_count: int  # Counter to prevent infinite loops


def generate_json(state):
    """Ask the LLM to convert the input text into JSON matching the schema.

    Builds a system prompt that embeds the ProductReview JSON schema. When a
    previous attempt failed validation, the error message and the rejected
    output are appended so the model can correct itself.

    Args:
        state: Graph state. ``input_text`` and ``attempt_count`` are required;
            ``validation_error`` and ``extracted_json`` are read when present.

    Returns:
        A partial state update holding ``extracted_json`` (the model's raw
        output) and the incremented ``attempt_count``.
    """
    attempt_number = state["attempt_count"] + 1
    print(f"--- Attempt {attempt_number}: Generating JSON ---")

    # SCHEMA_DEFINITION is a Python dict. Interpolating it directly would put
    # its repr (single quotes, `None`) in the prompt, which is not JSON, so
    # serialize it explicitly.
    schema_json = json.dumps(SCHEMA_DEFINITION, indent=2)

    # Base prompt
    system_message = (
        "You are an expert data extraction AI. Your task is to extract relevant "
        "information from the user's text and format it as a valid JSON object "
        f"that strictly adheres to the following JSON schema:\n\n{schema_json}\n\n"
        "Only output the raw JSON string. Do not include any other text or markdown."
    )

    # If a validation error exists from a previous attempt, add it to the prompt
    # to guide the LLM's correction. `.get` tolerates a state that was not
    # seeded with this key.
    previous_error = state.get("validation_error")
    if previous_error:
        print("--- Correction required. Adding error to prompt. ---")
        # The message list is rebuilt on every attempt, so the model only sees
        # its rejected output if it is quoted back here.
        previous_output = state.get("extracted_json")
        rejected_output = (
            f"Your previous output was:\n{previous_output}\n\n"
            if previous_output
            else ""
        )
        correction_message = (
            "Your previous attempt failed. Please correct it. "
            f"The error was:\n{previous_error}\n\n"
            f"{rejected_output}"
            "Review the error and the schema, then generate a new, valid JSON string."
        )
        system_message = f"{system_message}\n\n{correction_message}"

    messages_list = [
        SystemMessage(content=system_message),
        HumanMessage(content=state["input_text"]),
    ]
    json_output = llm.invoke(messages_list).content

    return {
        "extracted_json": json_output,
        "attempt_count": attempt_number,
    }


def validate_and_parse(state):
    """Check the LLM output in ``extracted_json`` against ProductReview.

    Args:
        state: Graph state; ``extracted_json`` is read.

    Returns:
        A partial state update whose ``validation_error`` is None when the
        output parses and validates, otherwise the error rendered as a string.
    """
    print("--- Validating JSON ---")
    candidate = state["extracted_json"]

    try:
        ProductReview.model_validate_json(candidate)
    except (ValidationError, json.JSONDecodeError) as exc:
        # Keep the message so the next generation attempt can quote it.
        print(f"--- Validation FAILED: {exc} ---")
        return {"validation_error": str(exc)}

    print("--- Validation Successful ---")
    return {"validation_error": None}


def decide_to_finish_or_retry(state):
    """Pick the next edge after validation.

    Args:
        state: Graph state; ``validation_error`` is always read and
            ``attempt_count`` is read only when validation failed.

    Returns:
        ``"finish"`` when validation succeeded or the attempt budget is used
        up, otherwise ``"retry"``.
    """
    # Success: nothing left to correct.
    if state["validation_error"] is None:
        return "finish"

    # Failure with attempts remaining: loop back to the generator.
    if state["attempt_count"] < MAX_ATTEMPTS:
        return "retry"

    # Failure with the budget exhausted: stop and leave the error in the state.
    print(f"---Max attempts ({MAX_ATTEMPTS}) reached. Finishing with error.---")
    return "finish"


class ExtractorLLM(BaseLLM):
    """Custom LLM that runs a generate -> validate -> retry extraction graph."""

    def __init__(self):
        """Assemble and compile the extraction graph, stored as ``self.graph``."""
        builder = StateGraph(ExtractorState)

        builder.add_node("generate_json", generate_json)
        builder.add_node("validate_and_parse", validate_and_parse)

        builder.add_edge(START, "generate_json")
        builder.add_edge("generate_json", "validate_and_parse")

        # Maps the label returned by decide_to_finish_or_retry to its target.
        routes = {
            "retry": "generate_json",  # Loop back to generator
            "finish": END,  # Exit the graph
        }
        builder.add_conditional_edges(
            "validate_and_parse", decide_to_finish_or_retry, routes
        )

        self.graph = builder.compile()

    def process(self, query, settings, trace):
        """Run the graph on the last message of ``query`` and report the result.

        Args:
            query: Request dict; the ``content`` of the last entry in
                ``query["messages"]`` is used as the text to extract from.
            settings: Not used by this implementation.
            trace: Trace object handed to LangchainToDKUTracer.

        Returns:
            A dict with one ``text`` key: the extracted JSON on success, or
            the last error and last attempted output on failure.
        """
        last_message = query["messages"][-1]

        final_state = self.graph.invoke(
            {
                "input_text": last_message["content"],
                "attempt_count": 0,
                "extracted_json": None,
                "validation_error": None,
            },
            config={"callbacks": [LangchainToDKUTracer(dku_trace=trace)]},
        )

        last_error = final_state["validation_error"]
        last_output = final_state["extracted_json"]

        if last_error is not None:
            return {
                "text": (
                    f"Failed to extract valid JSON after {MAX_ATTEMPTS} attempts.\n\n"
                    f"Last Error:\n{last_error}\n\n"
                    f"Last Attempted Output:\n{last_output}"
                )
            }

        return {"text": f"Successfully extracted data:\n\n{last_output}"}
