from dataiku.llm.guardrails import BaseGuardrail
import dataiku
import logging

class TopicsBoundariesGuardrail(BaseGuardrail):
    """Keep LLM queries and responses within a configured set of topics.

    A secondary LLM (``config["llm"]``) is asked which of ``config["topics"]``
    appear in the checked text. ``config["mode"]`` decides what counts as a
    violation:

    * ``"ALLOWLIST"``: none of the configured topics is present;
    * ``"DENYLIST"``: at least one configured topic is present.

    ``config["action"]`` decides what happens on a violation: ``"REJECT"``
    fails the call, ``"AUDIT"`` lets it pass with audit data, ``"DECLINE"``
    answers with a canned refusal. Any other action value only logs.
    """

    # Origin reported in the audit data of both query and response violations.
    _AUDIT_ORIGIN = "Topics Boundaries"
    # Canned answer used by the DECLINE action.
    _DECLINE_TEXT = "I do not feel comfortable talking about this"

    def set_config(self, config, plugin_config):
        """Store the guardrail settings; ``plugin_config`` is not used."""
        self.config = config

    def process(self, input, trace):
        """Check the query or response carried by *input* against the topic boundaries.

        :param input: guardrail payload. If it contains ``"completionResponse"``
            the response text is checked; otherwise, if it contains
            ``"completionQuery"``, the user messages are checked.
        :param trace: trace under which the topic-detection LLM call is recorded
        :returns: *input*, possibly mutated with a ``responseGuardrailResponse``
            or ``queryGuardrailResponse`` entry (and, for a declined response,
            replaced response text)
        :raises Exception: if ``config["mode"]`` is neither ALLOWLIST nor
            DENYLIST, or if the topic-detection LLM gives no usable JSON answer
        """
        if "completionResponse" in input:
            text_to_check = input["completionResponse"]["text"]
            present_topics = self._get_topics(trace, text_to_check)

            logging.info("Present topics: %s", present_topics)

            violation = self._find_violation(present_topics)
            if violation is not None:
                self._handle_response_failure(input, trace, violation)

        elif "completionQuery" in input:
            # Only what the user wrote is checked; other roles are ignored.
            text_to_check = "\n".join(
                m["content"]
                for m in input["completionQuery"]["messages"]
                if m["role"] == "user"
            )
            present_topics = self._get_topics(trace, text_to_check)

            logging.info("Present topics (query): %s", present_topics)

            violation = self._find_violation(present_topics)
            if violation is not None:
                self._handle_query_failure(input, trace, violation)

        return input

    def _find_violation(self, present_topics):
        """Return a description of the boundary violation, or ``None`` if there is none.

        :param present_topics: topics reported as present by the detection LLM
        :raises Exception: if ``config["mode"]`` is not a known mode
        """
        mode = self.config["mode"]
        if mode == "ALLOWLIST":
            if not present_topics:
                return "None of the allowed topics are present"
        elif mode == "DENYLIST":
            if present_topics:
                return "A forbidden topic is present: %s" % (present_topics,)
        else:
            raise Exception("Invalid mode: %s" % mode)
        return None

    def _audit_data(self, msg):
        """Build the audit entries describing the violation *msg*."""
        return [{
            "origin": self._AUDIT_ORIGIN,
            "violation": msg,
        }]

    def _handle_response_failure(self, input, trace, msg):
        """Apply ``config["action"]`` to a response that violates the boundaries.

        Sets ``input["responseGuardrailResponse"]``; DECLINE also overwrites
        the response text. Unknown actions only log the violation.
        """
        action = self.config["action"]
        logging.info("Topics Boundaries not respected (action:%s): %s", action, msg)
        if action == "REJECT":
            input["responseGuardrailResponse"] = {
                "action": "FAIL",
                "error": {
                    "message": msg
                }
            }
        elif action == "AUDIT":
            input["responseGuardrailResponse"] = {
                "action": "PASS_WITH_AUDIT",
                "auditData": self._audit_data(msg),
            }
        elif action == "DECLINE":
            input["responseGuardrailResponse"] = {
                "action": "RESPOND"
            }
            input["completionResponse"]["text"] = self._DECLINE_TEXT

    def _handle_query_failure(self, input, trace, msg):
        """Apply ``config["action"]`` to a query that violates the boundaries.

        Sets ``input["queryGuardrailResponse"]``. Unknown actions only log the
        violation.
        """
        action = self.config["action"]
        logging.info("Topics Boundaries not respected (action:%s): %s", action, msg)
        if action == "REJECT":
            input["queryGuardrailResponse"] = {
                "action": "FAIL",
                "error": {
                    "message": msg
                }
            }
        elif action == "AUDIT":
            input["queryGuardrailResponse"] = {
                "action": "PASS_WITH_AUDIT",
                "auditData": self._audit_data(msg),
            }
        elif action == "DECLINE":
            input["queryGuardrailResponse"] = {
                "action": "RESPOND",
                # NOTE(review): key spelling kept as-is; presumably this is the
                # exact key the host expects — confirm before "fixing" it.
                "overridenResponseText": self._DECLINE_TEXT,
                "auditData": self._audit_data(msg),
            }

    def _get_topics(self, trace, text):
        """Ask the configured LLM which of ``config["topics"]`` appear in *text*.

        :param trace: trace in which a subspan wraps the LLM call
        :param text: text to classify
        :returns: the parsed JSON answer of the LLM (expected to be a list of topics)
        :raises Exception: if the LLM answer yields no JSON value
        """
        topics_list = "\n".join("* %s" % t for t in self.config["topics"])
        llm = dataiku.api_client().get_default_project().get_llm(self.config["llm"])
        with trace.subspan("Listing topics present in the text"):

            resp = llm.new_completion().with_message("""
I want to check whether the text below matches a list of allowed or forbidden topics.

Given a list of possible topics, and the text, please return a JSON list, indicating which of the topics in the list are actually 
present in the text. If none of the possible topics are in the text, it's OK to return an empty JSON list.

Please respond with just the valid JSON list. Nothing else.

Here is the list of possible topics:

%s

The text to check follows

""" % topics_list, "system").with_message(text).execute()

        present_topics = resp.json
        if present_topics is None:
            # Fail loudly with a clear message rather than a TypeError from len(None) in the caller.
            raise Exception("Topic detection LLM did not return a usable JSON answer")
        return present_topics
