# -*- coding: utf-8 -*-
from dataiku import Dataset
from dataiku.customrecipe import (
    get_plugin_config,
    get_recipe_config,
    get_input_names_for_role,
    get_output_names_for_role,
)
from dkulib.dku_config import DkuConfig
from dkuconstants import TRANSPORT_MODE
from dkuconstants import DISTANCE_UNIT
from dkuconstants import REGIONS

class DkuConfigLoading:
    """Base loader turning Dataiku plugin and recipe settings into a validated ``DkuConfig``.

    ``__init__`` reads the plugin config, the recipe config and the input
    dataset schema. Subclasses call the ``_add_*`` helpers from their own
    ``load_settings`` to register the parameters their recipe needs.
    """

    def __init__(self):
        self.plugin_config = get_plugin_config()
        self.config = get_recipe_config()
        self.dku_config = DkuConfig()

        # Only the first dataset bound to "input_role" is used.
        input_dataset = get_input_names_for_role("input_role")[0]
        self.dku_config.add_param(
            name="input_dataset", value=Dataset(input_dataset), required=True
        )
        # Column names of the input dataset; used to validate column selections.
        self.input_dataset_columns = [
            p["name"] for p in self.dku_config.input_dataset.read_schema()
        ]

    def _content_error_message(self, error, column):
        """Return the user-facing error message for an error key.

        :param error: one of "missing", "invalid", "transport_mode", "time_threshold"
        :param column: column name interpolated into the "invalid" message
        :returns: the message string, or None when ``error`` is not a known key
        """
        if error == "missing":
            return "Missing input column."

        if error == "invalid":
            return "Invalid input column : {}.\n".format(column)

        if error == "transport_mode":
            return "You must select one of the transports mode.\n If your dataset contains several transports mode, you can create a column in your input dataset containing the transports mode you want to use, and select 'In column' from the dropdown.\n"

        if error == "time_threshold":
            return "You must enter a number in minutes."

        return None

    def _add_api_config(self):
        """Register ``api_endpoint_url``.

        A non-blank ``manualEndpointURL`` plugin setting takes precedence.
        Otherwise the URL is looked up in ``REGIONS`` using the ``endpoint``
        plugin setting, which defaults to "auto".
        """
        # `or` fallbacks: plugin settings may be stored as None rather than
        # absent, which previously crashed len() / the REGIONS lookup.
        endpoint = self.plugin_config.get("endpoint") or "auto"
        # strip(): a whitespace-only value must not override the region
        # lookup, and stray spaces must not end up inside the URL.
        manual_api_endpoint_url = (self.plugin_config.get("manualEndpointURL") or "").strip()
        if manual_api_endpoint_url:
            api_endpoint_url = manual_api_endpoint_url
        else:
            api_endpoint_url = REGIONS[endpoint]
        self.dku_config.add_param(name="api_endpoint_url", value=api_endpoint_url, required=True)

    def _add_transport_mode(self):
        """Register ``transport_mode``, cast to ``TRANSPORT_MODE`` and required to be one of its members."""
        err_msg = self._content_error_message("transport_mode", None)
        self.dku_config.add_param(
            name="transport_mode",
            value=self.config.get("transport_mode"),
            required=True,
            cast_to=TRANSPORT_MODE,
            checks=[
                {"type": "exists", "err_msg": err_msg},
                {"type": "in", "op": list(TRANSPORT_MODE), "err_msg": err_msg},
            ],
        )

    def _add_transport_mode_column(self):
        """Register ``transport_mode_column``, which must name an input dataset column."""
        transport_mode_column = self.config.get("transport_mode_column")
        self.dku_config.add_param(
            name="transport_mode_column",
            value=transport_mode_column,
            required=True,
            checks=self._get_column_checks(transport_mode_column, self.input_dataset_columns),
        )

    def _get_column_checks(self, column, input_columns):
        """Build the DkuConfig checks for a mandatory column parameter.

        :param column: the selected column name (used in the error message)
        :param input_columns: column names the selection must belong to
        :returns: an "exists" check followed by an "in input_columns" check
        """
        return [
            {
                "type": "exists",
                "err_msg": self._content_error_message("missing", None),
            },
            {
                "type": "in",
                "op": input_columns,
                "err_msg": self._content_error_message("invalid", column),
            },
        ]

    def _add_output_dataset(self):
        """Register the first dataset bound to the "main_output" role as ``output_dataset``."""
        output_dataset_name = get_output_names_for_role("main_output")[0]
        self.dku_config.add_param(
            name="output_dataset",
            value=Dataset(output_dataset_name),
            required=True,
        )

    def _add_parallelism_settings(self):
        """Copy the batch size and worker-count plugin settings into the config."""
        self.dku_config.add_param(name="batch_size", value=self.plugin_config.get("batch_size"), required=True)
        self.dku_config.add_param(name="routes_parallel_workers", value=self.plugin_config.get("routes_parallel_workers"), required=True)
        self.dku_config.add_param(name="isochrones_parallel_workers", value=self.plugin_config.get("isochrones_parallel_workers"), required=True)

    def _add_cache_settings(self):
        """Register ``cache_size`` (plugin setting scaled by 1024*1024) and ``use_cache``."""
        # NOTE(review): the 1024*1024 factor suggests the setting is in MB and
        # the stored value in bytes -- confirm against the cache consumer.
        # A missing "cache_size" setting would raise TypeError here.
        self.dku_config.add_param(name="cache_size", value=1024*1024*self.plugin_config.get("cache_size"), required=True)
        self.dku_config.add_param(name="use_cache", value=self.config.get("use_cache"), required=True)

class DkuConfigLoadingRouting(DkuConfigLoading):
    """Settings loader for the routing recipe.

    Adds the ``from_column``/``to_column`` selections, the distance unit and
    the itinerary flag on top of the shared API, transport-mode, parallelism,
    cache and output settings.
    """

    def load_settings(self):
        """Validate every routing parameter and return the populated DkuConfig."""
        self._add_api_config()
        self._add_required_columns()
        self._add_transport_mode()
        self._add_parallelism_settings()
        self._add_cache_settings()

        # A per-row transport mode column is only relevant for PARSE_COL;
        # otherwise the parameter is still registered, explicitly as None.
        mode_from_column = self.dku_config.transport_mode == TRANSPORT_MODE.PARSE_COL
        if not mode_from_column:
            self.dku_config.add_param(name="transport_mode_column", value=None)
        else:
            self._add_transport_mode_column()

        self._add_distance_unit()
        self._add_itinerary()
        self._add_output_dataset()
        return self.dku_config

    def _add_required_columns(self):
        """Register the mandatory ``from_column`` and ``to_column`` selections."""
        for param_name in ("from_column", "to_column"):
            selected_column = self.config.get(param_name)
            self.dku_config.add_param(
                name=param_name,
                value=selected_column,
                required=True,
                checks=self._get_column_checks(selected_column, self.input_dataset_columns),
            )

    def _add_itinerary(self):
        """Register the ``get_itinerary`` recipe setting."""
        itinerary_flag = self.config.get("get_itinerary")
        self.dku_config.add_param(name="get_itinerary", value=itinerary_flag, required=True)

    def _add_distance_unit(self):
        """Register ``distance_unit``, cast to ``DISTANCE_UNIT``."""
        unit_setting = self.config.get("distance_unit")
        self.dku_config.add_param(
            name="distance_unit",
            value=unit_setting,
            cast_to=DISTANCE_UNIT,
            required=True,
        )

class DkuConfigLoadingIsochrone(DkuConfigLoading):
    """Settings loader for the isochrone recipe.

    Adds the ``coords_column`` selection and the ``time_threshold`` setting
    on top of the shared API, transport-mode, parallelism, cache and output
    settings.
    """

    def load_settings(self):
        """Validate every isochrone parameter and return the populated DkuConfig."""
        self._add_api_config()
        self._add_required_column()
        self._add_time_threshold()
        self._add_transport_mode()
        self._add_parallelism_settings()
        self._add_cache_settings()

        # Only PARSE_COL reads transport modes from a column; otherwise the
        # parameter is still registered, explicitly as None.
        mode_from_column = self.dku_config.transport_mode == TRANSPORT_MODE.PARSE_COL
        if not mode_from_column:
            self.dku_config.add_param(name="transport_mode_column", value=None)
        else:
            self._add_transport_mode_column()

        self._add_output_dataset()
        return self.dku_config

    def _add_required_column(self):
        """Register the mandatory ``coords_column`` selection."""
        selected_column = self.config.get("coords_column")
        self.dku_config.add_param(
            name="coords_column",
            value=selected_column,
            required=True,
            checks=self._get_column_checks(selected_column, self.input_dataset_columns),
        )

    def _add_time_threshold(self):
        """Register ``time_threshold``; the raw value must be castable to float."""
        threshold_setting = self.config.get("time_threshold")
        castable_check = {
            "type": "is_castable",
            "op": float,
            "err_msg": self._content_error_message("time_threshold", None),
        }
        self.dku_config.add_param(
            name="time_threshold",
            value=threshold_setting,
            required=True,
            checks=[castable_check],
        )
