import dataiku
import logging
import json
from dataiku.llm.python import BaseLLM
import vertexai
from typing import Optional, AsyncGenerator, Dict, Any, List
from uuid import uuid4

from google.cloud.aiplatform_v1 import ReasoningEngineExecutionServiceAsyncClient
from google.cloud.aiplatform_v1.types import StreamQueryReasoningEngineRequest
from vertexai.agent_engines import AdkApp

from utils import get_credentials_from_vertexai_connection

# Module-level logger used by every VertexAIAgent method below.
logger = logging.getLogger(name="VertexAIAgent")


class VertexAIAgent(BaseLLM):
    """Dataiku custom LLM that forwards chat queries to a Vertex AI Agent Engine agent.

    The target agent is selected through the plugin settings
    (``vertexai_connection``, ``agent_framework``, ``agent_id``). Two agent
    frameworks are supported:

    - ``langchain``: the whole text conversation is sent on every request.
    - ``google-adk``: only the latest user message is sent; the conversation
      history is kept by an ADK session keyed on the Dataiku conversation id.
    """

    def __init__(self):
        pass

    def set_config(self, config, plugin_config):
        """Store the plugin instance configuration (``plugin_config`` is unused)."""
        self.config = config

    def _get_required_config(self, key: str) -> str:
        """
        Return the stripped value of a mandatory plugin setting.

        Args:
            key: Name of the setting in ``self.config``

        Returns:
            The setting value as a string, with surrounding whitespace removed

        Raises:
            ValueError: if the setting is missing, null or blank
        """
        value = self.config.get(key)
        if value is None or not str(value).strip():
            raise ValueError(f"Missing required plugin setting '{key}'")
        return str(value).strip()

    # ========================================================================
    # ADK-Framework FUNCTIONS
    # ========================================================================

    async def _get_or_create_adk_session(
        self,
        adk_app: AdkApp,
        user_id: str
    ) -> Dict[str, Any]:
        """
        Get the existing ADK session for ``user_id`` or create a new one.

        Note that "user_id" is actually the conversation id, so there will only
        ever be a single session per "user": the first listed session is reused.

        Args:
            adk_app: ADK app handle (as returned by ``agent_engines.get``)
            user_id: Conversation identifier used as the ADK user id

        Returns:
            Session dictionary with an 'id' key

        Raises:
            Exception: any error from the ADK session APIs is logged and re-raised
        """
        logger.info(f"Checking for existing ADK session for user: {user_id}")
        try:
            response = await adk_app.async_list_sessions(user_id=user_id)
            logger.debug(f"Session list response: {response}")

            # `or []` also covers a present-but-null 'sessions' field.
            sessions = response.get('sessions') or []
            if sessions:
                # The user id is the conversation id, so at most one session is expected.
                session = sessions[0]
                logger.info(f"Found existing ADK session: {session.get('id')}")
                return session

            logger.info("No existing ADK session found, creating new one")
            session = await adk_app.async_create_session(user_id=user_id)
            logger.info(f"Created new ADK session: {session.get('id')}")
            return session
        except Exception as e:
            logger.error(f"Error managing ADK session: {str(e)}", exc_info=True)
            raise

    @staticmethod
    def _extract_adk_part_text(part: Any) -> Optional[str]:
        """
        Return the user-facing text of a single ADK content part, or None to skip it.

        Only non-empty string 'text' values are returned. Parts flagged as model
        thoughts ('thought' truthy) are skipped because they are not part of the
        final answer. Function calls/responses are only logged.
        """
        if not isinstance(part, dict):
            return None

        text = part.get('text')
        if isinstance(text, str) and text:
            if part.get('thought'):
                logger.debug("Skipping ADK thought part")
                return None
            logger.debug(f"ADK text part: {text[:100]}...")
            return text

        # TODO (Future Enhancement): Implement handling for function calls.
        # This could involve streaming back a message like "Calling tool: [tool_name]..."
        # to provide intermediate feedback to the user under sources tab in Agent Hub.
        # TODO (Future Enhancement): Log or handle function responses for debugging.
        # For now, function calls and responses are logged but never streamed.
        for key, label in (('function_call', 'call'), ('function_response', 'response')):
            payload = part.get(key)
            if isinstance(payload, dict):
                logger.debug(f"ADK function {label}: {payload.get('name', 'unknown')}")
                break
        return None

    async def _parse_adk_response(
        self,
        stream: AsyncGenerator
    ) -> AsyncGenerator[Dict[str, str], None]:
        """
        Parse streaming response from ADK-based agents.

        ADK Response Structure:
        - Each event is a dict with a 'content' key
        - content['parts'] is a list of parts (possibly missing or null)
        - Each part can have 'text', 'function_call', or 'function_response'
        - We only stream back non-empty 'text' parts that are not model thoughts

        Example:
        {'content': {'parts': [{'text': 'I converted 1 USD...'}], 'role': 'model'}}

        Events without a dict 'content' fall back to a top-level 'text' field.
        Non-dict events are ignored.

        Args:
            stream: Async stream from ADK

        Yields:
            Dictionary with 'text' key containing response chunks
        """
        logger.info("Parsing ADK response stream")
        try:
            async for event in stream:
                logger.debug(f"ADK event received: {type(event)}")
                if not isinstance(event, dict):
                    continue

                content = event.get('content')
                if isinstance(content, dict):
                    # 'parts' may be absent or explicitly null.
                    for part in content.get('parts') or []:
                        text = self._extract_adk_part_text(part)
                        if text is not None:
                            yield {"text": text}
                elif 'text' in event:
                    # Fallback: check for direct text field
                    text = str(event['text'])
                    logger.debug(f"ADK direct text: {text[:100]}...")
                    yield {"text": text}
                else:
                    logger.debug(f"ADK event without text content: {list(event.keys())}")

            logger.info("ADK response stream completed")
        except Exception as e:
            logger.error(f"Error parsing ADK response: {str(e)}", exc_info=True)
            raise

    async def _process_adk_stream(
        self,
        gcp_credentials,
        gcp_region: str,
        gcp_project: str,
        agent_resource_name: str,
        messages: List[Dict[str, Any]],
        convo_id: Optional[str]
    ) -> AsyncGenerator[Dict[str, str], None]:
        """
        Process a streaming query for Google ADK-based agents.

        Only the latest user message is sent. When ``convo_id`` is set it is used
        as the ADK user id and an ADK session is reused or created for it, so the
        agent keeps context across turns. Otherwise a random user id is generated
        and no session id is passed.

        Args:
            gcp_credentials: Credentials from the Dataiku Vertex AI connection
            gcp_region: GCP region hosting the Agent Engine
            gcp_project: GCP project hosting the Agent Engine
            agent_resource_name: Full Agent Engine resource name
            messages: Dataiku messages ('role'/'content' dicts)
            convo_id: Dataiku conversation id, if any

        Yields:
            Dictionary with 'text' key containing response chunks

        Raises:
            Exception: query/stream errors are logged and re-raised
        """
        logger.info("Processing Google ADK stream")

        # Initialize Vertex AI client for ADK
        vertexai_client = vertexai.Client(
            project=gcp_project,
            location=gcp_region,
            credentials=gcp_credentials
        )
        logger.info("Vertex AI client initialized for ADK")

        # Get ADK App object
        logger.info(f"Getting ADK App object: {agent_resource_name}")
        adk_app = vertexai_client.agent_engines.get(name=agent_resource_name)
        logger.info("ADK App object retrieved successfully")

        # Determine user ID: conversation_id if exists, otherwise UUID4
        if convo_id:
            user_id = convo_id
            use_session = True
            logger.info(f"Conversation ID found: {convo_id}, will use session management")
        else:
            user_id = str(uuid4())
            use_session = False
            logger.warning(f"No conversation ID found, generated a random one: {user_id}, skipping session management")

        logger.info(f"Using 'user' (conversation) ID: {user_id}")

        # For ADK, we send only the latest user message as the session maintains context
        latest_user_message = ""
        for msg in reversed(messages):
            if msg.get("role") == "user" and msg.get("content"):
                latest_user_message = msg.get("content")
                break

        if not latest_user_message:
            logger.warning("No user message with content found; sending an empty message to the ADK agent")
        # Message content can contain user data: keep it out of INFO logs.
        logger.debug(f"Latest user message: {latest_user_message[:100]}...")

        try:
            if use_session:
                logger.info("Using session management for ADK stream")
                session = await self._get_or_create_adk_session(adk_app, user_id)
                session_id = session.get('id')

                stream = adk_app.async_stream_query(
                    user_id=user_id,
                    session_id=session_id,
                    message=latest_user_message
                )
            else:
                logger.info("Skipping session management, sending request with user_id only")
                stream = adk_app.async_stream_query(
                    user_id=user_id,
                    message=latest_user_message,
                )

            # Parse and yield responses using framework-specific parser
            async for chunk in self._parse_adk_response(stream):
                yield chunk

        except Exception as e:
            logger.error(f"Failed to query ADK Agent Engine: {str(e)}", exc_info=True)
            raise

    # ========================================================================
    # Langchain-Framework FUNCTIONS
    # ========================================================================

    def _create_langchain_input_payload(self, messages):
        """
        Convert Dataiku messages to Vertex AI Langchain Agent Engine format.

        Args:
            messages: List of Dataiku messages with 'role' and 'content' fields
                      Example: [{"role": "user", "content": "Hello"}, ...]

        Returns:
            List of messages in Agent Engine format:
            [{"role": "user/assistant/system", "type": "text", "text": "..."}]

        Filters:
            - Only includes messages with roles: user, system, assistant
            - Only includes messages whose content is a non-blank string (no
              support for images, multimodal parts or dataiku tools)
        """
        logger.info(f"Converting {len(messages)} Dataiku messages to Langchain format")

        agent_messages = []
        for idx, msg in enumerate(messages, start=1):
            role = msg.get("role", "")
            content = msg.get("content", "")

            if role not in ("user", "system", "assistant"):
                logger.info(f"Message #{idx} - Skipping message with role '{role}' (not user/system/assistant)")
                continue

            # Same check for every role: non-string content would crash .strip().
            if not isinstance(content, str):
                logger.info(f"Message #{idx} - Skipping {role} message without text content")
                continue
            if not content.strip():
                logger.info(f"Message #{idx} - Skipping {role} message with empty content")
                continue

            agent_messages.append({
                "role": role,
                "type": "text",
                "text": content
            })
            preview = f"{content[:100]}..." if len(content) > 100 else content
            # Message text can contain user data: keep it out of INFO logs.
            logger.debug(f"Message #{idx} - Role: {role}, Text: {preview}")

        logger.info(f"Converted to {len(agent_messages)} valid Langchain messages")
        return agent_messages

    @staticmethod
    def _iter_json_documents(payload: str):
        """
        Yield the JSON values contained in one decoded stream chunk.

        The chunk is first parsed as a single JSON document. If that fails, it is
        treated as newline-delimited JSON, since one chunk may carry several
        documents. Lines that are not valid JSON are logged and skipped.
        """
        try:
            document = json.loads(payload)
        except json.JSONDecodeError:
            pass
        else:
            yield document
            return

        # split("\n") rather than splitlines(): the latter also splits on
        # U+2028/U+2029, which JSON allows unescaped inside strings.
        for line in payload.split("\n"):
            if not line.strip():
                continue
            try:
                document = json.loads(line)
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse Langchain chunk as JSON: {line}")
                continue
            yield document

    async def _parse_langchain_response(
        self,
        stream: AsyncGenerator
    ) -> AsyncGenerator[Dict[str, str], None]:
        """
        Parse streaming response from Langchain-based agents.

        The Langchain Reasoning Engine returns HTTP streaming responses where:
        - Each response chunk is an HttpBody object with binary data (bytes)
        - response.data contains bytes that must be decoded to UTF-8 text
        - The decoded text holds one JSON document or several newline-delimited
          ones; each document (or each element of a JSON list) may contain an
          "output" field with the agent's text response

        Chunks that are not valid UTF-8 or not valid JSON are logged and skipped,
        so one malformed chunk does not abort the whole answer.

        Args:
            stream: Async stream from the Reasoning Engine (yields HttpBody objects)

        Yields:
            Dictionary with 'text' key containing response chunks
        """
        logger.info("Parsing Langchain response stream")
        try:
            async for response in stream:
                # response is HttpBody
                if not response.data:
                    continue
                try:
                    payload = response.data.decode("utf-8")
                except UnicodeDecodeError:
                    logger.warning(f"Failed to decode Langchain chunk as UTF-8: {response.data!r}")
                    continue

                for chunk_data in self._iter_json_documents(payload):
                    items = chunk_data if isinstance(chunk_data, list) else [chunk_data]
                    for item in items:
                        # TODO (Future Enhancement): Parse 'steps' and 'actions' from the Langchain stream.
                        # This would allow providing intermediate feedback on tool usage, similar to the ADK's function_call.
                        # The structure often includes details about which tool is being invoked.
                        if isinstance(item, dict) and "output" in item:
                            output_text = str(item["output"])
                            logger.debug(f"Langchain chunk: {output_text[:100]}...")
                            yield {"text": output_text}
            logger.info("Langchain response stream completed")
        except Exception as e:
            logger.error(f"Error parsing Langchain response: {str(e)}", exc_info=True)
            raise

    async def _process_langchain_stream(
        self,
        gcp_credentials,
        gcp_region: str,
        agent_resource_name: str,
        messages: List[Dict[str, Any]]
    ) -> AsyncGenerator[Dict[str, str], None]:
        """
        Process a streaming query for Langchain-based agents.

        The full text conversation is converted and sent with a single
        ``stream_query_reasoning_engine`` call to the regional endpoint.

        Args:
            gcp_credentials: Credentials from the Dataiku Vertex AI connection
            gcp_region: GCP region hosting the Agent Engine
            agent_resource_name: Full Agent Engine resource name
            messages: Dataiku messages ('role'/'content' dicts)

        Yields:
            Dictionary with 'text' key containing response chunks

        Raises:
            Exception: query/stream errors are logged and re-raised
        """
        logger.info("Processing Langchain stream")

        # Convert messages to Langchain format
        agent_messages = self._create_langchain_input_payload(messages)
        logger.info(f"Transformed {len(messages)} messages to {len(agent_messages)} Langchain messages")

        request = StreamQueryReasoningEngineRequest(
            name=agent_resource_name,
            input={"input": {"input": agent_messages}}
        )
        logger.info(f"Created Langchain request with {len(agent_messages)} messages")

        # The async client owns a gRPC transport; `async with` closes it when the
        # block exits instead of leaking one channel per query.
        async with ReasoningEngineExecutionServiceAsyncClient(
            credentials=gcp_credentials,
            client_options={"api_endpoint": f"{gcp_region}-aiplatform.googleapis.com"}
        ) as reasoning_client:
            logger.info("Sending async stream query to Langchain Agent Engine")
            try:
                stream = await reasoning_client.stream_query_reasoning_engine(request=request)
                logger.info("Successfully received stream response from Langchain Agent Engine")

                # Parse and yield responses using framework-specific parser
                async for chunk in self._parse_langchain_response(stream):
                    yield chunk

            except Exception as e:
                logger.error(f"Failed to query Langchain Agent Engine: {str(e)}", exc_info=True)
                raise

    # ========================================================================
    # MAIN ENTRY POINT
    # ========================================================================

    async def aprocess_stream(self, query, settings, trace):
        """
        Main entry point for processing streaming queries.

        Reads the plugin settings, resolves GCP credentials, project and region
        from the configured Dataiku Vertex AI connection, and routes to the
        framework-specific implementation.

        Args:
            query: Dataiku query dict; 'messages' and 'context.conversationId' are read
            settings: Unused
            trace: Unused

        Yields:
            Response chunks from the selected framework (dicts with a 'text' key).
            For ADK without a conversation id, a MISSING_CONVERSATION_ID event
            chunk is yielded first.

        Raises:
            ValueError: if a required plugin setting is missing or the agent
                framework is unsupported
        """
        logger.info("Processing async stream query for Vertex AI Agent")
        # The query holds the whole conversation: only dump it at DEBUG level.
        logger.debug(f"Full Query object: {query}")

        # Get configuration parameters (clear error instead of AttributeError on None)
        connection_name = self._get_required_config("vertexai_connection")
        agent_framework = self._get_required_config("agent_framework")
        agent_resource_name = self._get_required_config("agent_id")

        logger.info(f"Connection: {connection_name}")
        logger.info(f"Agent Framework: {agent_framework}")
        logger.info(f"Agent Resource Name: {agent_resource_name}")

        # Get the connection
        client = dataiku.api_client()
        connection = client.get_connection(connection_name)
        connection_info = connection.get_info()

        # Get credentials
        logger.info("Retrieving GCP credentials")
        gcp_credentials = get_credentials_from_vertexai_connection(connection_info)

        # Get region and project from connection params; an empty region would
        # otherwise produce an invalid "-aiplatform.googleapis.com" endpoint.
        conn_params = connection_info.get_params()
        gcp_region = conn_params.get('region') or 'us-central1'
        gcp_project = conn_params.get('project')
        logger.info(f"Using GCP Project: {gcp_project}")
        logger.info(f"Using GCP Region: {gcp_region}")

        messages = query.get("messages") or []
        logger.info(f"Total messages in conversation: {len(messages)}")

        # Route to appropriate framework implementation
        if agent_framework == "langchain":
            logger.info("Using Langchain framework")
            async for chunk in self._process_langchain_stream(
                gcp_credentials=gcp_credentials,
                gcp_region=gcp_region,
                agent_resource_name=agent_resource_name,
                messages=messages
            ):
                yield chunk
        elif agent_framework == "google-adk":
            logger.info("Using Google ADK framework")
            # ADK doesn't accept chat history as input: context is only kept via a
            # session, keyed on the conversationId from the Agent Hub context.
            # `or {}` also covers a present-but-null 'context'.
            convo_id = (query.get("context") or {}).get("conversationId")
            if not convo_id:
                logger.warning("No conversation id, history will not be handled")
                yield {"chunk": {
                    "type": "event",
                    "eventKind": "MISSING_CONVERSATION_ID",
                    "eventData": {"warning": "conversationId not found. ADK agents will not work properly without a conversationId. Agent will not remember any previous message"}
                }}

            async for chunk in self._process_adk_stream(
                gcp_credentials=gcp_credentials,
                gcp_region=gcp_region,
                gcp_project=gcp_project,
                agent_resource_name=agent_resource_name,
                messages=messages,
                convo_id=convo_id
            ):
                yield chunk
        else:
            # TODO: Add support for other frameworks supported by Vertex AI Agent Engine like LangGraph, AG2, etc.
            error_msg = f"Unsupported agent framework: {agent_framework}. Supported frameworks: 'langchain', 'google-adk'"
            logger.error(error_msg)
            raise ValueError(error_msg)
