import json
import logging
import traceback

from dataiku.base.socket_block_link import JavaLink, parse_javalink_args
from dataiku.base.utils import watch_stdin, get_json_friendly_error
from dataiku.core import debugging
from dataiku.doctor.timeseries_interactive.commands import create_scenario, compute_scenarios_forecasts

logger = logging.getLogger(__name__)

class TimeseriesInteractiveScoringServer(object):
    """Request/response loop serving timeseries interactive scoring commands.

    The server reads JSON requests from ``link``, dispatches each one on its
    ``"type"`` field and sends back a JSON result. Supported commands:

    - ``"create_scenario"``: replies with ``{"scenarios": ...}``
    - ``"compute_scenario"``: replies with ``{"forecasts": ...}``

    Any exception raised while handling a request is logged and reported to
    the peer as ``{"error": ...}``; the loop then keeps serving.
    """

    def __init__(self, link):
        """
        :param link: object exposing ``connect()``, ``read_json()``,
            ``send_json(obj)`` and ``close()``
        """
        self.link = link

    def start(self):
        """Connect the link and serve requests until ``read_json()`` returns None.

        The link is closed when the loop ends, including when an exception
        escapes it (e.g. the error reply itself cannot be sent).
        """
        logger.info("Timeseries Interactive Scoring Server started")
        self.link.connect()
        logger.info("Link connected")

        try:
            while True:
                try:
                    request = self.link.read_json()
                    if request is None:
                        # A None request is the signal to stop serving
                        break
                    self.link.send_json(self._handle_request(request))
                except Exception:
                    # logger.exception already appends the active traceback, so no
                    # extra argument (and no separate traceback printing) is needed.
                    logger.exception("Error while processing request")
                    # Must be called from within the except block
                    # (presumably reads the exception being handled - TODO confirm)
                    error = get_json_friendly_error()
                    self.link.send_json({"error": error})
        finally:
            self.link.close()

        logger.info("Timeseries Interactive Scoring Server stopped")

    def _handle_request(self, request):
        """Run the command described by ``request`` and return the response dict.

        :param request: dict with a ``"type"`` entry (command name) and a
            ``"params"`` entry holding the JSON-encoded command parameters
        :return: single-entry dict mapping the result key of the command to
            the value returned by the command
        :raises Exception: if the command name is not supported
        """
        cmd = request.get("type")
        params = json.loads(request.get("params"))
        logger.info("Timeseries Interactive Scoring - Command %s", cmd)

        if cmd == "create_scenario":
            result_key, command = "scenarios", create_scenario
        elif cmd == "compute_scenario":
            result_key, command = "forecasts", compute_scenarios_forecasts
        else:
            raise Exception("Unknown command: %s" % cmd)

        # Both commands take the same keyword arguments, all read from the params
        # (note that the "computation_params" entry feeds `computation_parameters`)
        return {
            result_key: command(
                split_desc=params.get("split_desc"),
                core_params=params.get("core_params"),
                preprocessing_folder=params.get("preprocessing_folder"),
                model_folder=params.get("model_folder"),
                split_folder=params.get("split_folder"),
                computation_parameters=params.get("computation_params")
            )
        }

def serve(port, secret, server_cert=None):
    """Build a JavaLink from the given settings and run the scoring server on it.

    :param port: forwarded as first positional argument to ``JavaLink``
    :param secret: forwarded as second positional argument to ``JavaLink``
    :param server_cert: forwarded to ``JavaLink`` as ``server_cert`` (default None)

    Returns once the server's ``start()`` call returns.
    """
    TimeseriesInteractiveScoringServer(
        JavaLink(port, secret, server_cert=server_cert)
    ).start()


if __name__ == "__main__":
    # Script entry point: set up logging and the debugging handler first,
    # then start watching stdin, read the link arguments and serve on them.
    logging.basicConfig(
        format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s',
        level=logging.INFO
    )
    debugging.install_handler()

    watch_stdin()
    port, secret, server_cert = parse_javalink_args()
    serve(port, secret, server_cert=server_cert)
