import pandas as pd
import inspect
from copy import deepcopy
import json
import re
from shapely.geometry import shape, MultiPolygon
from dkuconstants import KNOWN_ERROR_MESSAGES
from dkuconstants import OUTPUT_COLUMN_NAME_DESCRIPTIONS
import logging
from dkuconstants import TRANSPORT_MODE
from collections import namedtuple
from operator import itemgetter
from more_itertools import flatten
from sqlite3 import OperationalError

class BatchError(ValueError):
    """Custom exception raised if the Batch function fails.

    It subclasses ValueError, so handlers catching ValueError also catch it.
    """


def generate_unique(name, existing_names, prefix=None):
    """Generate a unique name among existing ones by suffixing a number and adding a prefix

    Non-ASCII characters and spaces in ``name`` are replaced by underscores before
    the optional prefix is applied. If the resulting base name is already taken,
    the suffixes ``_1``, ``_2``, ... are tried in turn on that same base name.

    Args:
        name: Input name
        existing_names: Collection of existing names (anything supporting ``in``)
        prefix: Optional prefix to add to the output name

    Returns:
       Unique name with a number suffix in case of conflict, and an optional prefix

    Raises:
        RuntimeError: If no free name is found after 1000 numbered attempts

    """
    # Replace non-ASCII characters and spaces by an underscore _
    name = re.sub(r"[^\x00-\x7F]", "_", name).replace(" ", "_")
    base_name = f"{prefix}_{name}" if prefix else name
    if base_name not in existing_names:
        return base_name
    for j in range(1, 1001):
        # Always suffix the base name so that suffixes do not pile up (a_1, a_2, ... not a_1_2)
        candidate = f"{base_name}_{j}"
        if candidate not in existing_names:
            return candidate
    raise RuntimeError(f"Failed to generate a unique name for '{name}'")


def make_column_names_unique(df):
    """Rename duplicated column labels of ``df`` in place so that every label is unique.

    The first occurrence of a label is kept as is. Each later occurrence gets a
    ``_<n>`` suffix (``_1``, ``_2``, ...). A suffix is skipped when the resulting
    label is already used by another column, so the renaming cannot create a new
    duplicate.

    Args:
        df: DataFrame whose ``columns`` attribute is overwritten

    Returns:
        None, ``df`` is modified in place
    """
    taken = set(df.columns)  # every label in use, including the ones generated below
    seen = set()  # labels whose first occurrence was already emitted
    next_suffix = {}  # next suffix number to try for each duplicated label
    unique_labels = []
    for label in df.columns:
        if label not in seen:
            seen.add(label)
            unique_labels.append(label)
            continue
        suffix = next_suffix.get(label, 1)
        # f-string formatting also works for non-string labels
        candidate = f"{label}_{suffix}"
        while candidate in taken:
            suffix += 1
            candidate = f"{label}_{suffix}"
        next_suffix[label] = suffix + 1
        taken.add(candidate)
        unique_labels.append(candidate)
    df.columns = unique_labels


def parse_one_route_row(row, get_itinerary=False):
    """Parse one raw route response into a Series of route attributes.

    Args:
        row: JSON string holding the route response, or a falsy value (empty
            string, None) when no response is available
        get_itinerary: When truthy, an ``itinerary`` entry is added to the output.
            Defaults to False so the function can be applied without extra arguments

    Returns:
        pandas Series with ``time`` and ``distance`` entries, plus ``itinerary`` when
        requested. Entries absent from the response, or all entries when ``row``
        is falsy, are None
    """
    fields = ["time", "distance"]
    if get_itinerary:
        fields.append("itinerary")
    # A missing response is treated like an empty JSON object: every field is None
    response = json.loads(row) if row else {}
    return pd.Series({field: response.get(field) for field in fields})


def parse_one_isochrone_row(row):
    """Parse one raw isochrone response into a Series.

    Args:
        row: JSON string holding the isochrone response, or a falsy value when
            no response is available

    Returns:
        pandas Series built from the decoded JSON, or a Series with a single
        ``isochrone`` entry set to None when ``row`` is falsy
    """
    if not row:
        return pd.Series({"isochrone": None})
    decoded = json.loads(row)
    return pd.Series(decoded)


def apply_final_postprocessing(df, parsed_response):
    """Merge the parsed responses into ``df`` and finalize the column layout.

    The parsed columns are appended to ``df``, the raw ``output_response`` column
    is dropped, the error columns are moved to the end and duplicated column
    names are made unique.

    Args:
        df: DataFrame holding the input columns, the ``output_response`` column
            and, optionally, the error columns
        parsed_response: DataFrame of parsed response values, concatenated
            column-wise to ``df``

    Returns:
        New DataFrame with the final column order and unique column names
    """
    output_df = pd.concat([df, parsed_response], axis=1)
    output_df = output_df.drop("output_response", axis=1)
    error_columns = ["output_error_message", "output_error_type", "output_error_raw"]
    labels = list(output_df.columns)
    # Reorder by position rather than by label: selecting by label returns every
    # matching column for each requested label, so a label shared by the input and the
    # parsed response (e.g. "time") would multiply the columns before they get renamed.
    regular_positions = [
        position for position, label in enumerate(labels) if label not in error_columns
    ]
    error_positions = [
        position
        for error_column in error_columns
        for position, label in enumerate(labels)
        if label == error_column
    ]
    output_df = output_df.iloc[:, regular_positions + error_positions]
    make_column_names_unique(output_df)
    return output_df


def postprocess_routes_df(df, get_itinerary=None):
    """Parse the ``output_response`` column of a routes dataframe and finalize the output.

    Args:
        df: DataFrame containing an ``output_response`` column of raw route responses
        get_itinerary: When truthy, the itinerary is extracted as well. None (the
            default) behaves like False

    Returns:
        DataFrame with the parsed route columns added and the raw response removed
    """
    # Always forward get_itinerary: parse_one_route_row is applied with it as an extra
    # positional argument, and None is falsy so it simply disables the itinerary column.
    parsed_response = df["output_response"].apply(parse_one_route_row, args=(get_itinerary,))
    return apply_final_postprocessing(df, parsed_response)


def postprocess_isochrones_df(df):
    """Parse the ``output_response`` column of an isochrones dataframe and finalize the output.

    Args:
        df: DataFrame containing an ``output_response`` column of raw isochrone responses

    Returns:
        Result of apply_final_postprocessing on ``df`` and the parsed responses
    """
    raw_responses = df["output_response"]
    return apply_final_postprocessing(df, raw_responses.apply(parse_one_isochrone_row))


def post_process_results(df, results, output_column_names):
    """Build the output dataframe from the result records.

    Args:
        df: Input dataframe, used for its column order and dtypes
        results: Records (one dict per row) holding input and output values
        output_column_names: Names of the output columns, placed after the input columns

    Returns:
        DataFrame with the input columns followed by the output columns. Output
        columns are cast to str, input columns to the dtypes they have in ``df``
    """
    # Output columns default to str; the input dtypes are applied last and win on conflict
    target_dtypes = dict.fromkeys(output_column_names, str)
    target_dtypes.update(dict(df.dtypes))
    ordered_columns = list(df.columns) + list(output_column_names)
    records_df = pd.DataFrame.from_records(results)
    aligned_df = records_df.reindex(columns=ordered_columns)
    return aligned_df.astype(target_dtypes)


def postprocess_routes_batch(df, results, output_column_names, index_col_name, get_itinerary):
    """Assemble batched route results into the final routes dataframe.

    Args:
        df: Input dataframe
        results: Iterable of batches, each batch being a list of row dictionaries
        output_column_names: Names of the output columns added to each row
        index_col_name: Key used to sort the rows; it is removed from every row afterwards
        get_itinerary: Forwarded to postprocess_routes_df

    Returns:
        DataFrame returned by postprocess_routes_df
    """
    ordered_rows = sorted(flatten(results), key=itemgetter(index_col_name))
    stripped_rows = []
    for row in ordered_rows:
        # The index key only serves for sorting and must not become a column
        stripped_rows.append({key: val for key, val in row.items() if key != index_col_name})
    combined_df = post_process_results(df, stripped_rows, output_column_names)
    return postprocess_routes_df(combined_df, get_itinerary)


def postprocess_isochrones_batch(df, results, output_column_names, index_col_name):
    """Assemble batched isochrone results into the final isochrones dataframe.

    Args:
        df: Input dataframe
        results: Iterable of batches, each batch being a list of row dictionaries
        output_column_names: Names of the output columns added to each row
        index_col_name: Key used to sort the rows; it is removed from every row afterwards

    Returns:
        DataFrame returned by postprocess_isochrones_df
    """
    ordered_rows = sorted(flatten(results), key=itemgetter(index_col_name))
    stripped_rows = []
    for row in ordered_rows:
        # The index key only serves for sorting and must not become a column
        stripped_rows.append({key: val for key, val in row.items() if key != index_col_name})
    combined_df = post_process_results(df, stripped_rows, output_column_names)
    return postprocess_isochrones_df(combined_df)


def parse_batch_response_default(batch, response, output_column_names):
    """Pairs each row of the batch with its response and blank error columns.

    Responses and rows are matched by position. Pairing stops at the shorter of
    the two sequences. Existing row entries are kept and take precedence over
    the output columns if a key is shared.

    Args:
        batch: List of input rows, each row being a dict
        response: Sequence of responses in the same order as the batch
        output_column_names: Object exposing the ``response``, ``error_message``,
            ``error_type`` and ``error_raw`` column names as attributes

    Returns:
        New list of dicts, one per (response, row) pair, holding the output
        columns followed by the row entries
    """
    enriched_rows = []
    for single_response, row in zip(response, batch):
        enriched = {
            output_column_names.response: single_response,
            output_column_names.error_message: "",
            output_column_names.error_type: "",
            output_column_names.error_raw: "",
        }
        # Row entries are written last so that they override the output columns on conflict
        enriched.update(row)
        enriched_rows.append(enriched)
    return enriched_rows


def apply_function_with_error_logging(batch, function, output_column_names, batch_support, **function_kwargs):
    """Wraps a row-by-row or batch function with error logging
    """
    output = deepcopy(batch)
    for output_column in output_column_names:
        for output_row in output:
            output_row[output_column] = ""
    try:
        if not batch_support:
            # In the row-by-row case, there is only one element in the list as batch_size=1
            response = [(function(row=batch[0], **function_kwargs))]
        else:
            response = function(batch=batch, **function_kwargs)
        output = parse_batch_response_default(
            batch=batch,
            response=response,
            output_column_names=output_column_names,
        )
        errors = [
            row[output_column_names.error_message]
            for row in output
            if row[output_column_names.error_message]
        ]
        if errors:
            raise BatchError(str(errors))

    except OperationalError:
        # Catching OperationalError thrown by cache issues - fail the job in this case
        raise Exception("Fatal error reading cache - It is possible that someone deleted the cache while this recipe was running.")
    except (Exception,) + (BatchError,) as error:
        error_type = str(type(error).__qualname__)
        module = inspect.getmodule(error)
        if module:
            error_type = f"{module.__name__}.{error_type}"
        for output_row in output:
            output_row[output_column_names.error_message] = str(error)
            output_row[output_column_names.error_type] = error_type
            output_row[output_column_names.error_raw] = str(error.args)
    return output


def get_unique_output_column_names(existing_names):
    """Returns a named tuple of unique, prefixed output column names.

    The tuple has one field per key of OUTPUT_COLUMN_NAME_DESCRIPTIONS. Each
    field holds the name produced by generate_unique for that key with the
    ``output`` prefix.

    Args:
        existing_names: Names already in use, passed to generate_unique

    Returns:
        OutputColumnNameTuple instance holding the generated column names
    """
    field_names = OUTPUT_COLUMN_NAME_DESCRIPTIONS.keys()
    OutputColumnNameTuple = namedtuple("OutputColumnNameTuple", field_names)
    unique_names = []
    for column_name in OutputColumnNameTuple._fields:
        unique_name = generate_unique(
            name=column_name,
            existing_names=existing_names,
            prefix="output",
        )
        unique_names.append(unique_name)
    return OutputColumnNameTuple(*unique_names)


def decode_encoded_polyline(encoded):
    """Decode an encoded polyline string into a list of [lon, lat] pairs.

    Each character carries 5 payload bits (its code minus 63); a value of 0x20 or
    more means another character follows for the same number. The lowest bit of
    each decoded number is its sign flag, and each number is a delta added to the
    previous coordinate.

    Args:
        encoded: Encoded polyline string

    Returns:
        List of ``[lon, lat]`` float pairs, scaled by ``1e-6 * 10`` and rounded to
        6 decimals

    Raises:
        IndexError: If the string ends in the middle of a coordinate
    """
    inv = 1.0 / 1e6
    points = []
    running = [0, 0]  # last decoded (lat, lon) as integers
    cursor = 0
    total = len(encoded)
    while cursor < total:
        # One point is two consecutive numbers: latitude delta, then longitude delta
        for axis in range(2):
            accumulator = 0
            shift = 0
            while True:
                chunk = ord(encoded[cursor]) - 63
                cursor += 1
                accumulator |= (chunk & 0x1f) << shift
                shift += 5
                if chunk < 0x20:
                    # No continuation bit: this number is complete
                    break
            if accumulator & 1:
                delta = ~(accumulator >> 1)
            else:
                delta = accumulator >> 1
            running[axis] = running[axis] + delta
        lat, lon = running
        # Emit lon,lat (flipped from the encoded lat,lon order), scaled and rounded
        points.append([float('%.6f' % (lon * inv * 10)), float('%.6f' % (lat * inv * 10))])
    return points


def geojson_to_wkt(gjson):
    """Convert a GeoJSON geometry, or a list of GeoJSON features, to WKT.

    Args:
        gjson: Either a list of features (dicts with a ``geometry`` entry), which
            are combined into a MultiPolygon, or a single geometry dict with a
            ``coordinates`` entry

    Returns:
        WKT string of the geometry, or an empty string when a single geometry
        has no coordinates
    """
    if isinstance(gjson, list):
        polygons = [shape(feature["geometry"]) for feature in gjson]
        return MultiPolygon(polygons).wkt
    if not gjson["coordinates"]:
        return ""
    return shape(gjson).wkt


def raise_error_400(response, logger):
    """Raise a ValueError carrying the message of a 400 response, if there is one.

    Args:
        response: Object exposing ``status_code``, ``content`` and ``json()``
        logger: Logger used to report a failure to parse the response body

    Raises:
        ValueError: If the status code is 400 and the JSON body contains a
            non-empty ``message`` entry
    """
    payload = {}
    try:
        is_bad_request = response.status_code == 400
        if is_bad_request and response.content:
            payload = response.json()
    except Exception as err:
        # A body that cannot be parsed is only logged; no error is raised for it
        logger.warning("Error parsing response:{}".format(err))
    raw_error_message = payload.get("message")
    if not raw_error_message:
        return
    raise ValueError(improve_error_message(raw_error_message))


def improve_error_message(message):
    """Prefix a message with the explanation of the first known error it contains.

    Args:
        message: Raw error message

    Returns:
        ``"<explanation> / <message>"`` for the first key of KNOWN_ERROR_MESSAGES
        found in ``message``, or ``message`` unchanged when no key matches
    """
    matched_sample = next(
        (sample for sample in KNOWN_ERROR_MESSAGES if sample in message),
        None,
    )
    if matched_sample is None:
        return message
    explanation = KNOWN_ERROR_MESSAGES.get(matched_sample)
    return "{} / {}".format(explanation, message)


def log_debug_messages(logger, response, params, headers, query_method):
    """Log the response status and the request details at DEBUG level.

    Nothing is logged unless DEBUG is enabled on ``logger``.

    Args:
        logger: Logger receiving the messages
        response: Object exposing ``status_code``
        params: Request parameters, logged via ``str``
        headers: Request headers, logged via ``str``
        query_method: Query method, logged via ``str``
    """
    if not logger.isEnabledFor(logging.DEBUG):
        return
    labelled_values = (
        ("state: ", response.status_code),
        ("params: ", params),
        ("headers: ", headers),
        ("query_method : ", query_method),
    )
    for label, value in labelled_values:
        logger.debug(label + str(value))


def get_local_transport_mode(global_transport_mode):
    """Translate a TRANSPORT_MODE value into the local transport mode name.

    Args:
        global_transport_mode: One of TRANSPORT_MODE.CAR, TRANSPORT_MODE.BIKE or
            TRANSPORT_MODE.PEDESTRIAN

    Returns:
        ``"car"``, ``"bike"`` or ``"foot"`` respectively

    Raises:
        ValueError: If the transport mode is none of the supported values
    """
    local_names = (
        (TRANSPORT_MODE.CAR, "car"),
        (TRANSPORT_MODE.BIKE, "bike"),
        (TRANSPORT_MODE.PEDESTRIAN, "foot"),
    )
    for supported_mode, local_name in local_names:
        if global_transport_mode == supported_mode:
            return local_name
    raise ValueError(f"Invalid transport mode: {global_transport_mode}. Supported transport modes are : car, bicycle, pedestrian")

