from uuid import uuid4
import time
import logging
import json

import dataiku
from dataiku.llm.python import BaseLLM

import aws_bedrock_agents.auth
import aws_bedrock_agents.boto3
import aws_bedrock_agents.event

# Install the plugin-wide log format on the root logger, then expose the
# named logger that the rest of this module writes to.
logging.basicConfig(
    format='AWS Bedrock Agent plugin %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("awsbedrockagentplugin")

class DkuMissingPluginParameter(Exception):
    """Raised when a required plugin setting is missing or has an invalid value."""

class AwsBedrockAgent(BaseLLM):
    def __init__(self):
        """Create the DSS API client that is later used to look up AWS credentials."""
        api_client = dataiku.api_client()
        self.dss_client = api_client
        
    def set_config(self, config, plugin_config):
        """
        Read the plugin settings, check that the required ones are present,
        and reset the cached AWS credentials state.

        Raises DkuMissingPluginParameter when a required setting is empty or
        when the agent mode is neither "agent" nor "agentcore".
        """
        # Settings shared by both agent modes
        self.aws_region = config.get("aws_region")
        self.agent_mode = config.get("agent_mode", "agent")
        self.s3_connection_name = config.get("s3_connection")

        shared_requirements = (
            (self.aws_region, "Configuration error: missing AWS Region."),
            (self.s3_connection_name, "Configuration error: missing S3 connection."),
            (self.agent_mode, "Configuration error: missing Agent Mode."),
        )
        for setting, problem in shared_requirements:
            if not setting:
                raise DkuMissingPluginParameter(problem)

        # Reject unsupported modes before reading any mode-specific setting
        if self.agent_mode not in ("agent", "agentcore"):
            raise DkuMissingPluginParameter(f"Configuration error: invalid Agent Mode '{self.agent_mode}'.")

        if self.agent_mode == "agentcore":
            # AgentCore runtime: only the ARN is mandatory, the qualifier may be empty
            self.agentcore_runtime_arn = config.get("agentcore_runtime_arn")
            self.agentcore_qualifier = config.get("agentcore_qualifier")

            if not self.agentcore_runtime_arn:
                raise DkuMissingPluginParameter("Configuration error: missing Bedrock AgentCore ARN.")
        else:
            # Classic Bedrock agent: both the agent ID and its alias ID are mandatory
            self.agent_id = config.get("agent_id")
            self.alias_id = config.get("alias_id")

            agent_requirements = (
                (self.agent_id, "Configuration error: missing Bedrock Agent ID."),
                (self.alias_id, "Configuration error: missing Bedrock Agent Alias ID."),
            )
            for setting, problem in agent_requirements:
                if not setting:
                    raise DkuMissingPluginParameter(problem)

        # An expiration of 0 makes the first credentials check fetch fresh ones
        self.aws_auth = {}
        self.credentials_expiration = 0

    def _refresh_credentials_if_needed(self):
        """
        Fetch fresh AWS credentials when the cached ones are missing or about to expire.

        The cached credentials are considered stale once the current time (in
        milliseconds) plus a 10 second buffer reaches ``self.credentials_expiration``.
        On refresh, ``self.aws_auth`` and ``self.credentials_expiration`` are both
        replaced with the pair returned by ``aws_bedrock_agents.auth.get_aws_credentials``.
        """
        expiry_buffer_ms = 10000  # 10s buffer
        now_with_buffer_ms = time.time() * 1000 + expiry_buffer_ms

        if now_with_buffer_ms < self.credentials_expiration:
            # Still valid: do not claim a retrieval that did not happen.
            logger.debug(
                f"Reusing cached AWS credentials with expiration timestamp {self.credentials_expiration}ms."
            )
            return

        self.aws_auth, self.credentials_expiration = aws_bedrock_agents.auth.get_aws_credentials(
            self.dss_client, self.s3_connection_name)

        logger.info(f"Retrieved AWS credentials with expiration timestamp {self.credentials_expiration}ms.")


    def process_stream(self, query, settings, trace):
        """
        Query the configured Bedrock agent (or AgentCore runtime) from DSS and
        stream its answer.

        High-level logic:

        1. Refresh AWS credentials if needed.
        2. Open an "AWS_BEDROCK_AGENT_CALL" trace subspan and invoke either the
           Bedrock agent or the AgentCore runtime, depending on ``self.agent_mode``.
        3. Process the response in stream:
            - Bedrock agent: if the agent attempts to "return control" to the
              user, yield an error text and stop (as DSS agents don't currently
              support HITL). Otherwise yield the agent text chunks and collect
              KB sources and tool calling information.
            - AgentCore: yield the text deltas of the event stream (or the whole
              body for a non-streaming JSON response).
        4. Once the response stream is complete, yield the sources footer (Bedrock
           agent only, when there are sources) and record the outputs and metadata
           on the subspan.

        Args:
            query: Request mapping. Only the "content" of the last entry of
                ``query["messages"]`` is sent to the agent. The optional
                ``query["context"]["conversationId"]`` identifies the conversation.
            settings: Not used by this implementation.
            trace: Trace object; its ``subspan(name)`` context manager must yield a
                span exposing ``inputs``, ``outputs`` and ``attributes`` mappings.

        Yields:
            dict: ``{"text": ...}`` answer chunks, ``{"chunk": {"type": "event", ...}}``
            informational events, and ``{"footer": ...}`` carrying the sources.
        """
        # `context` may be absent or explicitly null in the query.
        conversation_id = (query.get("context") or {}).get("conversationId")
        if not conversation_id:
            logger.warning("No conversation id, history will not be handled")
            yield {"chunk": {
                "type": "event",
                "eventKind": "MISSING_CONVERSATION_ID",
                "eventData": {"warning": "conversationId not found. Bedrock agents will not work properly without a conversationId. Agent will not remember any previous message"},
            }}

        # Only the latest message is forwarded to the agent.
        user_message = query["messages"][-1]["content"]

        # (1) Handle credentials
        self._refresh_credentials_if_needed()

        # (2) Invoke Bedrock agent or AgentCore based on mode
        with trace.subspan("AWS_BEDROCK_AGENT_CALL") as subspan:
            subspan.inputs["messages"] = [{"role": "user", "text": user_message}]
            subspan.attributes["AwsBedrockAgentMode"] = self.agent_mode

            if self.agent_mode == "agentcore":
                yield from self._stream_agentcore(user_message, conversation_id, subspan)
            else:
                yield from self._stream_bedrock_agent(user_message, conversation_id, subspan)

    def _stream_agentcore(self, user_message, conversation_id, subspan):
        """Invoke the AgentCore runtime, yield its answer and record it on the subspan."""
        subspan.attributes["AwsBedrockAgentcoreRuntimeArn"] = self.agentcore_runtime_arn
        subspan.attributes["AwsBedrockAgentcoreQualifier"] = self.agentcore_qualifier

        logger.info(
            f"Invoking Bedrock AgentCore '{self.agentcore_runtime_arn}' (version '{self.agentcore_qualifier}') "
            f"in region {self.aws_region} with conversation ID '{conversation_id}'."
        )

        response = aws_bedrock_agents.boto3.invoke_agentcore(
            aws_auth=self.aws_auth,
            aws_region=self.aws_region,
            agent_runtime_arn=self.agentcore_runtime_arn,
            qualifier=self.agentcore_qualifier,
            user_message=user_message,
            conversation_id=conversation_id
        )

        content = []
        # Treat a missing or null content type as "unknown" instead of failing on `in`.
        content_type = response.get("contentType") or ""

        if "text/event-stream" in content_type:
            yield from self._stream_agentcore_events(response["response"], content)

        elif content_type == "application/json":
            # TODO: Implement proper parsing of non-streaming AgentCore response
            # Join the raw bytes before decoding: a multi-byte UTF-8 character may
            # be split across two body chunks and would fail to decode on its own.
            raw_chunks = list(response.get("response", []))
            if raw_chunks:
                content.append(b"".join(raw_chunks).decode("utf-8"))

            yield {"text": "".join(content)}

        else:
            # NOTE(review): unexpected content type. The raw response mapping is not
            # text, so surface its string form rather than a non-string "text" value.
            logger.warning(f"Unexpected AgentCore response content type: '{content_type}'")
            yield {"text": str(response)}

        if content:
            subspan.outputs["text"] = "".join(content)

    def _stream_agentcore_events(self, body, content):
        """
        Parse an AgentCore "text/event-stream" body and yield its text deltas.

        Every yielded text fragment is also appended to ``content`` so that the
        caller can rebuild the full answer. Malformed lines and unrecognised
        payloads are reported as "event" chunks instead of interrupting the stream.
        """
        # NOTE(review): the small chunk size presumably keeps streaming latency low
        # - confirm before changing it.
        for raw_line in body.iter_lines(chunk_size=10):
            if not raw_line:
                continue

            line = raw_line.decode("utf-8")
            if not line.startswith("data: "):
                continue
            line = line[6:]

            try:
                line_data = json.loads(line)
            except json.JSONDecodeError:
                logger.warning(f"Could not decode AgentCore stream line as JSON: {line}")
                yield {"chunk": {
                    "type": "event", "eventKind": "MALFORMED_AGENTCORE_STREAM_LINE",
                    "eventData": {"line": line}
                }}
                continue

            # Plain JSON strings carry nothing to forward.
            if isinstance(line_data, str):
                continue

            # Any other non-object payload (number, null, list, ...) cannot be
            # inspected by key: report it as unknown rather than failing.
            if not isinstance(line_data, dict):
                logger.warning(f"Unknown AgentCore stream line data: {line_data}")
                yield {"chunk": {
                    "type": "event", "eventKind": "AWS_BEDROCK_AGENTCORE_UNKNOWN_STREAM_CHUNK",
                    "eventData": {"lineData": line_data}
                }}
                continue

            event = line_data.get("event")
            if not isinstance(event, dict):
                event = {}

            # Actual data message
            if "contentBlockDelta" in event:
                block_delta = event["contentBlockDelta"]
                if "delta" in block_delta and "text" in block_delta["delta"]:
                    text = block_delta["delta"]["text"]
                    yield {"text": text}
                    content.append(text)

            # Internal messages
            elif "messageStart" in event:
                logger.info("AgentCore stream: got 'messageStart': %s", line_data)
            elif "contentBlockStop" in event:
                logger.info("AgentCore stream: got 'contentBlockStop': %s", line_data)
            elif "messageStop" in event:
                logger.info("AgentCore stream: got 'messageStop': %s", line_data)
            elif "init_event_loop" in line_data:
                logger.info("AgentCore stream: got 'init_event_loop': %s", line_data)
            elif "start_event_loop" in line_data:
                logger.info("AgentCore stream: got 'start_event_loop': %s", line_data)
            elif "start" in line_data:
                logger.info("AgentCore stream: got 'start': %s", line_data)
            elif "message" in line_data and "role" in line_data["message"]:
                logger.info("AgentCore stream: got message: %s", line_data)
            elif "metadata" in event:
                logger.info("AgentCore stream: got metadata: %s", line_data)
                # TODO: Log usage from metadata
                # "event":{"metadata":{"usage":{"inputTokens":13,"outputTokens":120,"totalTokens":133},"metrics":{"latencyMs":1967}}}

            # Unknown, raise an event
            else:
                logger.warning(f"Unknown AgentCore stream line data: {line_data}")
                yield {"chunk": {
                    "type": "event", "eventKind": "AWS_BEDROCK_AGENTCORE_UNKNOWN_STREAM_CHUNK",
                    "eventData": {"lineData": line_data}
                }}

    def _stream_bedrock_agent(self, user_message, conversation_id, subspan):
        """Invoke the Bedrock agent, yield its answer and sources, and record them on the subspan."""
        subspan.attributes["AwsBedrockAgentId"] = self.agent_id
        subspan.attributes["AwsBedrockAgentAliasId"] = self.alias_id

        logger.info(
            f"Invoking Bedrock agent '{self.agent_id}' (alias '{self.alias_id}') in region {self.aws_region} "
            f"with conversation ID '{conversation_id}'."
        )

        response = aws_bedrock_agents.boto3.invoke_agent(
            aws_auth=self.aws_auth,
            aws_region=self.aws_region,
            agent_id=self.agent_id,
            alias_id=self.alias_id,
            user_message=user_message,
            conversation_id=conversation_id
        )

        # (3) Process Bedrock agent response
        text_parts = []
        sources = []
        kb_source_items = []
        tool_sources = {}

        logger.info("Processing Bedrock agent response in stream")
        for event in response.get("completion", []):
            # Check if Agent is attempting to return control to the user (i.e. request tool call confirmation).
            if "returnControl" in event:
                agent_identifier = f"'{self.agent_id}' (alias '{self.alias_id}')"
                error_text = aws_bedrock_agents.event.extract_return_control_error(event, agent_identifier)
                logger.error(f"Bedrock agent {agent_identifier} attempted to return control: {error_text}")
                yield {"text": error_text}
                # Stop here: no sources or outputs are recorded for an aborted call.
                return

            # Stream successful agent text response
            if "chunk" in event:
                text = event["chunk"]["bytes"].decode()
                text_parts.append(text)
                yield {"text": text}

            # Extract Bedrock Knowledge Bank citations and Action Group invocations from the event (if present)
            aws_bedrock_agents.event.extract_kb_source_items(event, kb_source_items)
            aws_bedrock_agents.event.extract_tool_call_sources(event, tool_sources)

        full_text_response = "".join(text_parts)
        logger.info(f"Bedrock agent response: {full_text_response}")

        # (4) Process and yield kb and tool call sources
        aws_bedrock_agents.event.process_kb_and_tool_sources(sources, kb_source_items, tool_sources)
        if sources:
            logger.info(f"Bedrock agent response sources: {sources}")
            yield {"footer": {"additionalInformation": {"sources": sources}}}

        # Add more info in trace
        subspan.outputs["text"] = full_text_response
        subspan.attributes["AwsBedrockKnowledgeBankSources"] = kb_source_items
        subspan.attributes["AwsBedrockToolCalls"] = list(tool_sources.values())
