from requests.exceptions import HTTPError
import requests
import os
import re
import logging
from dkuconstants import TRANSPORT_MODE
from dkuconstants import CACHE_DIR, CACHE_NUM_SHARDS, CACHE_TIMEOUT
import json
from concurrent.futures import ThreadPoolExecutor
from diskcache import FanoutCache
import pwd
from concurrent.futures import as_completed
import utils

# Logger for this module; it is also handed to the utils helpers
# (log_debug_messages / raise_error_400) so their messages share this name.
logger = logging.getLogger(name=__name__)

class GeoRouter:
    """Base client for a routing / isochrone HTTP API.

    This class implements the provider-independent parts:
    - building endpoint URLs;
    - sending GET requests that carry the license header;
    - parsing WKT ``POINT`` cells;
    - optional on-disk caching of API results;
    - batched, multi-threaded processing of an input dataset into an output
      dataset.

    Subclasses must provide ``get_route_querystring``,
    ``get_isochrone_querystring`` and the ``get_*_from_response`` extractors
    declared at the bottom of this class.
    """

    # Multiplier turning the user-facing ``cache_size`` into the diskcache
    # ``size_limit`` in bytes, i.e. ``cache_size`` is expressed in megabytes.
    # Shared by run_get_routes and run_get_isochrones so both agree.
    _CACHE_SIZE_UNIT_BYTES = 1000000

    # ``timeout`` forwarded to ``requests.get``. None (the default) waits
    # indefinitely, as before; set a number of seconds to bound each call.
    request_timeout = None

    def __init__(self, api_endpoint_url, dku_license_id=None):
        # Base URL of the API; endpoint names are appended by get_url().
        self.api_host = api_endpoint_url
        # Sent as the "X-DKU-LicenseId" header on every request.
        self.dku_license_id = dku_license_id
        # Captures the two coordinates of a WKT "POINT (x y)" string.
        self.point_pattern = re.compile(r"POINT\s*\((.*) (.*)\)")

    # General methods
    def get_url(self, endpoint_name):
        """Return the URL of ``endpoint_name`` under the API base URL.

        The parts are always joined with "/" rather than with
        ``os.path.join``, whose separator depends on the operating system.
        """
        return "{}/{}".format(self.api_host.rstrip("/"), endpoint_name.lstrip("/"))

    def parse_geopoint(self, geop):
        """Return the ``(x, y)`` coordinate strings of the WKT point ``geop``.

        Raises:
            IndexError: if ``geop`` contains no ``POINT (x y)``.
            TypeError: if ``geop`` is not a string.
        """
        return self.point_pattern.findall(geop)[0]

    def enrich_headers(self, headers):
        """Add the license header to ``headers`` in place."""
        headers["X-DKU-LicenseId"] = self.dku_license_id

    def make_get_request(self, url, params, headers):
        """Send a GET request to ``url`` and return the response.

        ``headers`` is mutated to include the license header.

        Raises:
            HTTPError: for error statuses; the response body, when present,
                is appended to the error message. ``utils.raise_error_400`` is
                called first (presumably raising a more specific error for
                HTTP 400 -- confirm in utils).
        """
        self.enrich_headers(headers)
        response = requests.get(url=url, params=params, headers=headers, timeout=self.request_timeout)
        utils.log_debug_messages(logger, response, params, headers, "GET")
        try:
            response.raise_for_status()
        except HTTPError as err:
            utils.raise_error_400(response, logger)
            if response.content:
                err.args = (f"{err.args[0]}\nResponse content: {response.content}",)
            raise
        return response

    @staticmethod
    def response_for_empty_point(get_itinerary):
        """Return the JSON route result for a row with a missing start or end point.

        ``time`` and ``distance`` are null. ``itinerary`` is added (also null)
        only when ``get_itinerary is True``, mirroring parse_route_response.
        """
        route_response = {"time": None, "distance": None}
        if get_itinerary is True:
            route_response["itinerary"] = None
        return json.dumps(route_response)

    # Shared helpers
    @staticmethod
    def _is_missing(value):
        """Return True when a geometry cell is empty (None or NaN)."""
        return value is None or str(value) == "nan"

    @staticmethod
    def _resolve_transport_mode(row, transport_mode, transport_mode_col):
        """Return the transport mode to use for ``row``.

        With ``TRANSPORT_MODE.PARSE_COL``, the mode is read from
        ``row[transport_mode_col]``. It is converted to ``TRANSPORT_MODE`` when
        possible; otherwise the raw cell value is passed through unchanged.
        """
        if transport_mode == TRANSPORT_MODE.PARSE_COL:
            raw_mode = row[transport_mode_col]
            try:
                return TRANSPORT_MODE(raw_mode)
            except Exception:
                return raw_mode
        return transport_mode

    def _parse_point_column(self, value, col_name):
        """Parse a WKT point cell, raising a SyntaxError that names ``col_name``."""
        try:
            return self.parse_geopoint(value)
        except Exception as err:
            raise SyntaxError("Parsing error: Input column '{0}' should contain geometries in WKT format.".format(col_name)) from err

    @staticmethod
    def _get_with_cache(cache, row_key, fetch):
        """Return the cached result for ``row_key``, calling ``fetch`` when needed.

        ``fetch()`` must return ``(raw_response, parsed_response)``. The parsed
        response is stored only when the HTTP status is 200. A key that is
        present but reads back as None is treated as a cache read that timed
        out because of concurrent operations, and the result is recomputed.
        """
        if row_key in cache:
            response = cache.get(row_key)
            if response is not None:
                return response
            logger.info("Cache read timed out - falling back to nominal processing")
        raw_response, response = fetch()
        if raw_response.status_code == 200:
            cache[row_key] = response
        return response

    @classmethod
    def _cache_size_bytes(cls, cache_size):
        """Convert a cache size in megabytes into a diskcache ``size_limit`` in bytes."""
        return cls._CACHE_SIZE_UNIT_BYTES * cache_size

    def _open_cache(self, cache_size):
        """Open the on-disk result cache in the current user's home directory.

        The returned FanoutCache is meant to be used as a context manager.
        """
        cache_dir = os.path.join(pwd.getpwuid(os.getuid()).pw_dir, CACHE_DIR)
        return FanoutCache(cache_dir, shards=CACHE_NUM_SHARDS,
                           size_limit=self._cache_size_bytes(cache_size), timeout=CACHE_TIMEOUT)

    def _process_dataset(self, input_dataset, batch_size, parallel_workers, output_dataset, out_writer,
                         row_function, postprocess_batch, **row_kwargs):
        """Apply ``row_function`` to every row of ``input_dataset`` and write the output.

        Rows are read in chunks of ``batch_size``. Each row is submitted to a
        pool of ``parallel_workers`` threads through
        ``utils.apply_function_with_error_logging``. Results are collected in
        completion order, so each row carries its dataframe index in an extra
        column whose name is passed to ``postprocess_batch``. The output schema
        is written from the first batch.

        Args:
            row_function: per-row callable (e.g. process_row_for_routes).
            postprocess_batch: callable
                ``(input_df, results, output_column_names, index_col_name)``
                returning the dataframe to write.
            **row_kwargs: forwarded to ``row_function``.
        """
        output_column_names = None
        index_col_name = None
        # One pool for the whole run rather than one per batch; each batch is
        # still fully awaited (as_completed) before it is written.
        with ThreadPoolExecutor(max_workers=parallel_workers) as pool:
            batches = input_dataset.iter_dataframes(chunksize=batch_size)
            for num_batch, input_df in enumerate(batches, start=1):
                first = num_batch == 1
                if first:
                    output_column_names = utils.get_unique_output_column_names(
                        existing_names=input_df.columns
                    )
                    index_col_name = utils.generate_unique("index", input_df.columns)
                futures = []
                for index, row in input_df.iterrows():
                    row_data = row.to_dict()
                    row_data[index_col_name] = index
                    futures.append(
                        pool.submit(utils.apply_function_with_error_logging,
                                    batch=[row_data],
                                    function=row_function,
                                    output_column_names=output_column_names,
                                    batch_support=False,
                                    **row_kwargs)
                    )
                results = [future.result() for future in as_completed(futures)]
                output_df = postprocess_batch(input_df, results, output_column_names, index_col_name)
                if first:
                    output_dataset.write_schema_from_dataframe(output_df)
                out_writer.write_dataframe(output_df)
                logger.info("Processed %s batches of %s records.", num_batch, batch_size)

    # Routes
    def parse_route_response(self, response, distance_unit, get_itinerary):
        """Extract time, distance and (optionally) the itinerary as WKT from an API response."""
        route_response = {
            "time": round(self.get_time_from_response(response), 3),
            "distance": round(self.get_distance_from_response(response, distance_unit), 3)
        }
        if get_itinerary is True:
            route_response["itinerary"] = utils.geojson_to_wkt(self.get_itinerary_from_response(response))
        return route_response

    def call_route_endpoint(self, from_coords, to_coords, transport_mode, get_itinerary):
        """Call the "route" endpoint with the subclass-built query string."""
        params = self.get_route_querystring(from_coords, to_coords, transport_mode, get_itinerary)
        headers = {}
        return self.make_get_request(url=self.get_url("route"), params=params, headers=headers)

    def get_routes_response(self, parsed_from_coords, parsed_to_coords, real_transport_mode, get_itinerary, distance_unit):
        """Return ``(raw_response, parsed_route_dict)`` for one origin/destination pair."""
        raw_response = self.call_route_endpoint(parsed_from_coords, parsed_to_coords, real_transport_mode, get_itinerary)
        response = self.parse_route_response(raw_response.json(), distance_unit, get_itinerary)
        return raw_response, response

    def process_row_for_routes(self, row, from_col, to_col, transport_mode_col, transport_mode, get_itinerary, distance_unit, use_cache, cache):
        """Compute the route for one row and return it as a JSON string.

        Args:
            row: mapping of column name to cell value.
            from_col, to_col: columns holding the WKT start / end points.
            transport_mode_col: column read when ``transport_mode`` is
                ``TRANSPORT_MODE.PARSE_COL``.
            transport_mode: fixed mode, or ``TRANSPORT_MODE.PARSE_COL``.
            get_itinerary: include the itinerary geometry when ``True``.
            distance_unit: forwarded to ``get_distance_from_response``.
            use_cache, cache: when ``use_cache`` is truthy, results are looked
                up in / stored into ``cache``.

        Returns:
            JSON with ``time``, ``distance`` and optionally ``itinerary``; the
            values are null when either point is missing (None or NaN).

        Raises:
            SyntaxError: if a non-missing point is not a WKT ``POINT``.
        """
        real_transport_mode = self._resolve_transport_mode(row, transport_mode, transport_mode_col)
        dep_geopoint = row[from_col]
        arr_geopoint = row[to_col]
        if self._is_missing(dep_geopoint):
            return self.response_for_empty_point(get_itinerary)
        parsed_from_coords = self._parse_point_column(dep_geopoint, from_col)
        if self._is_missing(arr_geopoint):
            return self.response_for_empty_point(get_itinerary)
        parsed_to_coords = self._parse_point_column(arr_geopoint, to_col)

        def fetch():
            return self.get_routes_response(parsed_from_coords, parsed_to_coords, real_transport_mode, get_itinerary, distance_unit)

        if use_cache:
            row_key = ",".join([str(dep_geopoint), str(arr_geopoint), str(real_transport_mode), str(get_itinerary), str(distance_unit)])
            response = self._get_with_cache(cache, row_key, fetch)
        else:
            _, response = fetch()
        return json.dumps(response)

    def process_routes_df(self,
                          input_dataset,
                          batch_size,
                          parallel_workers,
                          from_col,
                          to_col,
                          transport_mode_col,
                          transport_mode,
                          get_itinerary,
                          distance_unit,
                          use_cache,
                          cache,
                          output_dataset,
                          out_writer):
        """Compute routes for every row of ``input_dataset`` and write them via ``out_writer``."""
        def postprocess(input_df, results, output_column_names, index_col_name):
            return utils.postprocess_routes_batch(input_df, results, output_column_names, index_col_name, get_itinerary)

        self._process_dataset(input_dataset, batch_size, parallel_workers, output_dataset, out_writer,
                              row_function=self.process_row_for_routes,
                              postprocess_batch=postprocess,
                              from_col=from_col,
                              to_col=to_col,
                              transport_mode_col=transport_mode_col,
                              transport_mode=transport_mode,
                              get_itinerary=get_itinerary,
                              distance_unit=distance_unit,
                              use_cache=use_cache,
                              cache=cache)

    def run_get_routes(self,
                       input_dataset,
                       output_dataset,
                       from_col,
                       to_col,
                       transport_mode,
                       distance_unit,
                       get_itinerary=False,
                       transport_mode_col=None,
                       parallel_workers=10,
                       batch_size=10,
                       use_cache=False,
                       cache_size=1000):
        """Compute routes for ``input_dataset`` into ``output_dataset``.

        When ``use_cache`` is true, results are cached on disk with a size
        limit of ``cache_size`` megabytes.
        """
        with output_dataset.get_writer() as out_writer:
            if use_cache:
                with self._open_cache(cache_size) as cache:
                    self.process_routes_df(input_dataset, batch_size, parallel_workers, from_col, to_col, transport_mode_col, transport_mode, get_itinerary,
                                           distance_unit, use_cache, cache, output_dataset, out_writer)
            else:
                self.process_routes_df(input_dataset, batch_size, parallel_workers, from_col, to_col, transport_mode_col, transport_mode, get_itinerary,
                                       distance_unit, False, None, output_dataset, out_writer)

    # Isochrones
    def parse_isochrone_response(self, response):
        """Extract the isochrone geometry as WKT from an API response."""
        return {
            "isochrone": utils.geojson_to_wkt(self.get_isochrone_geometry_from_response(response)),
        }

    def call_isochrone_endpoint(self, coords, time_threshold, transport_mode):
        """Call the "isochrone" endpoint with the subclass-built query string."""
        params = self.get_isochrone_querystring(coords, time_threshold, transport_mode)
        headers = {}
        return self.make_get_request(url=self.get_url("isochrone"), params=params, headers=headers)

    def get_isochrones_response(self, parsed_geopoint, time_threshold, real_transport_mode):
        """Return ``(raw_response, parsed_isochrone_dict)`` for one point."""
        raw_response = self.call_isochrone_endpoint(parsed_geopoint, float(time_threshold), real_transport_mode)
        response = self.parse_isochrone_response(raw_response.json())
        return raw_response, response

    def process_row_for_isochrone(self, row, coords_col, transport_mode, transport_mode_col, time_threshold, use_cache, cache):
        """Compute the isochrone for one row and return it as a JSON string.

        Returns ``{"isochrone": null}`` when the point is missing (None or NaN).

        Raises:
            SyntaxError: if a non-missing point is not a WKT ``POINT``.
        """
        real_transport_mode = self._resolve_transport_mode(row, transport_mode, transport_mode_col)
        geopoint = row[coords_col]
        if self._is_missing(geopoint):
            return json.dumps({"isochrone": None})
        parsed_geopoint = self._parse_point_column(geopoint, coords_col)

        def fetch():
            return self.get_isochrones_response(parsed_geopoint, time_threshold, real_transport_mode)

        if use_cache:
            row_key = ",".join([str(geopoint), str(time_threshold), str(real_transport_mode)])
            response = self._get_with_cache(cache, row_key, fetch)
        else:
            _, response = fetch()
        return json.dumps(response)

    def process_isochrones_df(self,
                              input_dataset,
                              batch_size,
                              parallel_workers,
                              coords_col,
                              transport_mode,
                              transport_mode_col,
                              time_threshold,
                              output_dataset,
                              out_writer,
                              use_cache,
                              cache
                              ):
        """Compute isochrones for every row of ``input_dataset`` and write them via ``out_writer``."""
        self._process_dataset(input_dataset, batch_size, parallel_workers, output_dataset, out_writer,
                              row_function=self.process_row_for_isochrone,
                              postprocess_batch=utils.postprocess_isochrones_batch,
                              coords_col=coords_col,
                              transport_mode=transport_mode,
                              transport_mode_col=transport_mode_col,
                              time_threshold=time_threshold,
                              use_cache=use_cache,
                              cache=cache)

    def run_get_isochrones(self,
                           input_dataset,
                           output_dataset,
                           coords_col,
                           time_threshold,
                           transport_mode,
                           transport_mode_col=None,
                           batch_size=10,
                           parallel_workers=10,
                           use_cache=False,
                           cache_size=1000):
        """Compute isochrones for ``input_dataset`` into ``output_dataset``.

        When ``use_cache`` is true, results are cached on disk with a size
        limit of ``cache_size`` megabytes. The units are the same as in
        run_get_routes; ``cache_size`` used to be passed through as bytes,
        which made the default cache hold only 1000 bytes.
        """
        with output_dataset.get_writer() as out_writer:
            if use_cache:
                with self._open_cache(cache_size) as cache:
                    self.process_isochrones_df(input_dataset, batch_size, parallel_workers, coords_col, transport_mode, transport_mode_col, time_threshold,
                                               output_dataset, out_writer, use_cache, cache)
            else:
                self.process_isochrones_df(input_dataset, batch_size, parallel_workers, coords_col, transport_mode, transport_mode_col, time_threshold,
                                           output_dataset, out_writer, False, None)

    # Provider-specific hooks, implemented by subclasses.
    @staticmethod
    def get_isochrone_geometry_from_response(response):
        """Return the isochrone geometry (GeoJSON) from a decoded API response."""
        raise NotImplementedError()

    @staticmethod
    def get_itinerary_from_response(response):
        """Return the itinerary geometry (GeoJSON) from a decoded API response."""
        raise NotImplementedError()

    @staticmethod
    def get_time_from_response(response):
        """Return the route duration from a decoded API response."""
        raise NotImplementedError()

    @staticmethod
    def get_distance_from_response(response, distance_unit):
        """Return the route distance in ``distance_unit`` from a decoded API response."""
        raise NotImplementedError()
