# coding: utf-8
from __future__ import unicode_literals

from dataiku.llm.evaluation.agent_evaluation_recipe import AgentEvaluationRecipe
from dataiku.llm.evaluation.genai_eval_recipe_desc import GenAIEvalRecipeDesc
from dataiku.llm.evaluation.llm_evaluation_recipe import LLMEvaluationRecipe

"""
Main entry point for testing custom metrics.
This is a server implementing the commands defined in the PythonKernelProtocol Java class.
"""

import logging

from dataiku.base.socket_block_link import JavaLink, parse_javalink_args
from dataiku.base.utils import watch_stdin
from dataiku.core import debugging

logger = logging.getLogger(__name__)


class TestCustomMetricParams(object):
    """Read-only view over one "test custom metric" command decoded from the link.

    Wraps the raw command mapping and exposes the fields the protocol handler
    needs as properties.
    """

    def __init__(self, data):
        """
        :param data: decoded command, dict-like. Keys read: ``recipeDesc``
            (required), ``inputDatasetSmartName``, ``indexOfMetricToCompute``
            and ``isAgent`` (all optional, ``None`` when absent).
        """
        self._data = data

    @property
    def recipe_desc(self):
        """Raw recipe description (``recipeDesc``); missing key raises (KeyError for a dict)."""
        return self._data["recipeDesc"]

    @property
    def input_dataset_smartname(self):
        """Smart name of the input dataset, or None if not provided."""
        return self._data.get("inputDatasetSmartName")

    @property
    def index_of_metric_to_compute(self):
        """Index of the custom metric to compute, or None if not provided."""
        return self._data.get("indexOfMetricToCompute")

    @property
    def is_agent(self):
        """Raw ``isAgent`` value (truthy selects agent evaluation), or None if not provided."""
        return self._data.get("isAgent")

    # Backward-compatible camelCase alias: existing callers read ``params.isAgent``.
    isAgent = is_agent


class TestCustomMetricProtocol(object):
    """Serve "test custom metric" commands read from a JavaLink.

    Every command read from the link gets exactly one JSON reply sent back on
    the same link.
    """

    def __init__(self, link):
        """
        :param link: connected JavaLink used to read commands and send replies.
        """
        self.link = link

    def _handle_test_custom_metric(self, params):
        """Compute one custom metric for the recipe described in ``params`` and reply.

        On success, the reply is whatever ``recipe.test_custom_metric`` returns.
        If anything raises while building the recipe or computing the metric, the
        reply is ``{"metric": None, "didSucceed": False, "error": <message>}``,
        so the client always receives an answer.

        :param params: TestCustomMetricParams for the current command.
        """
        try:
            parsed_recipe_desc = GenAIEvalRecipeDesc(params.recipe_desc)
            recipe_class = AgentEvaluationRecipe if params.isAgent else LLMEvaluationRecipe
            recipe = recipe_class.build_for_test(parsed_recipe_desc, params.input_dataset_smartname)
            metric = parsed_recipe_desc.custom_metrics[params.index_of_metric_to_compute]
            result = recipe.test_custom_metric(metric)
        except Exception as e:
            # Keep the full traceback in the local log: the reply only carries a message.
            logger.exception("Failed to test custom metric")
            # str(e) is empty for exceptions raised without arguments; fall back to
            # the exception type so the client still gets a meaningful error.
            error_result = {"metric": None, "didSucceed": False, "error": str(e) or type(e).__name__}
            self.link.send_json(error_result)
        else:
            self.link.send_json(result)

    def start(self):
        """Process commands until ``read_json`` raises EOFError (connection closed).

        Every command is treated as a test-custom-metric request; there is no
        dispatch on the command type.
        """
        try:
            while True:
                command = self.link.read_json()
                self._handle_test_custom_metric(TestCustomMetricParams(command))
        except EOFError:
            logger.info("Connection with the test custom metric client closed")


def serve(port, secret, server_cert=None):
    """Open a JavaLink and serve test-custom-metric commands over it.

    Blocks until the protocol loop ends. The link is closed afterwards, whether
    the loop returned normally or raised.

    :param port: port passed to JavaLink.
    :param secret: secret passed to JavaLink (presumably authenticates the connection).
    :param server_cert: optional server certificate forwarded to JavaLink.
    """
    java_link = JavaLink(port, secret, server_cert=server_cert)
    java_link.connect()
    handler = TestCustomMetricProtocol(java_link)
    try:
        handler.start()
    finally:
        # Always release the connection, even if the handler raised.
        java_link.close()


if __name__ == "__main__":
    # Log line layout: timestamp, pid/thread, level, logger name, message.
    _LOG_FORMAT = ('[%(asctime)s] [%(process)s/%(threadName)s] '
                   '[%(levelname)s] [%(name)s] %(message)s')
    logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
    debugging.install_handler()

    # NOTE(review): presumably ends this process once stdin is closed by the parent;
    # confirm in dataiku.base.utils.watch_stdin.
    watch_stdin()

    cli_port, cli_secret, cli_server_cert = parse_javalink_args()
    serve(cli_port, cli_secret, server_cert=cli_server_cert)
