import re
from ast import literal_eval
from functools import partial
import shapely
from shapely.geometry import Point, Polygon
from shapely.ops import transform
import geopy.distance
import numpy as np
import copy
import pyproj
from .config import GEOJSON_POLYGON_STARTER
from sklearn.neighbors import KDTree


def read_wkt_geo_point(geo_point):
    """Parse a WKT point string such as ``"POINT(2.35 48.85)"`` into a tuple of floats.

    The ``POINT`` keyword is matched case-insensitively, and any amount of
    whitespace is tolerated around the keyword, the parentheses and the numbers.
    Coordinates are returned in the order they appear in the WKT text
    (``create_wkt_geo_point`` writes longitude first).

    Raises
    ------
    ValueError
        If a remaining token is not a number.
    """
    # Raw string: "\(" in a normal literal is an invalid escape sequence.
    coordinates_text = re.sub(r"POINT\s*|[()]", "", geo_point, flags=re.IGNORECASE)
    # split() with no argument ignores leading/trailing/repeated whitespace.
    return tuple(float(token) for token in coordinates_text.split())

def create_wkt_geo_point(latitude, longitude):
    """Build a WKT point string; WKT lists longitude (x) before latitude (y)."""
    return f"POINT({longitude} {latitude})"
    
def read_geo_point(geo_point):
    """Parse ``geo_point`` into a coordinate tuple/list.

    ``geo_point`` is converted to ``str`` first. Strings containing "point"
    (any case) are parsed as WKT with ``read_wkt_geo_point``. Anything else is
    parsed as a Python literal, e.g. ``"(2.35, 48.85)"`` or ``[2.35, 48.85]``.

    Returns
    -------
    The parsed value, or the string ``'ValueError : geo_point is not readable'``
    when the input cannot be parsed.
    """
    geo_point = str(geo_point)
    try:
        if "point" in geo_point.lower():
            return read_wkt_geo_point(geo_point)
        # literal_eval evaluates literals only (no arbitrary code execution).
        # It reports malformed input as ValueError, TypeError or SyntaxError.
        return literal_eval(geo_point)
    except (ValueError, TypeError, SyntaxError):
        # NOTE(review): returning a sentinel string instead of raising is kept
        # for backward compatibility; callers must compare against it.
        return 'ValueError : geo_point is not readable'

def read_as_shapely_polygon(polygon):
    """Return ``polygon`` as a shapely ``Polygon``, constructing one only when needed.

    An existing ``Polygon`` instance is returned unchanged (same object).
    """
    if isinstance(polygon, Polygon):
        return polygon
    return Polygon(polygon)

def read_as_regular_polygon(shapely_polygon):
    """Convert a shapely polygon's exterior ring into a list of ``[x, y]`` pairs.

    Only the exterior ring is read; interior rings are not included. The x and y
    values are handed to ``from_coordinates_to_polygon`` as longitudes and latitudes.
    """
    exterior_ring = shapely_polygon.exterior
    ring_xs, ring_ys = exterior_ring.coords.xy
    return from_coordinates_to_polygon(ring_xs, ring_ys)

def geo_point_is_in_geojson(geo_point, geojson):
    """Return True if ``geo_point`` lies inside the Polygon of a GeoJSON feature.

    ``geojson["geometry"]["coordinates"]`` is read as a GeoJSON Polygon. The
    first ring is the exterior boundary and any further rings are holes
    (RFC 7946, section 3.1.6), so a point inside a hole is reported as outside.
    The test uses shapely's ``within``, so points on the boundary give False.
    """
    rings = geojson["geometry"]["coordinates"]
    # Pass the remaining rings as holes; an empty list means "no holes".
    shapely_geometry = Polygon(rings[0], rings[1:])
    return Point(geo_point).within(shapely_geometry)

def polygon_contains_geo_point(geometry_polygon, geo_point):
    """Return True if the polygon built from ``geometry_polygon`` contains ``geo_point``.

    Uses shapely's ``contains`` predicate; a point lying exactly on the
    boundary is not contained.
    """
    return Polygon(geometry_polygon).contains(Point(geo_point))

def compute_polygon_envelope(polygon):
    """Return the corner coordinates of ``polygon``'s axis-aligned bounding box.

    ``polygon`` is anything ``shapely.geometry.Polygon`` accepts. The result is
    the ``(xs, ys)`` pair of the envelope's exterior ring, which this module
    treats as ``(longitudes, latitudes)``.
    """
    bounding_box = Polygon(polygon).envelope
    box_xs, box_ys = bounding_box.exterior.coords.xy
    return box_xs, box_ys

def extract_polygon_envelope_metadata(polygon):
    """Summarise the bounding box of a ``[longitude, latitude]`` polygon.

    Returns
    -------
    dict with keys
        ``envelope_data``: dict with ``min_lat``, ``min_lon``, ``max_lat``, ``max_lon``.
        ``envelope_data_str``: ``"(min_lat,min_lon,max_lat,max_lon)"``.
        ``envelope_length``: longer of the two measured box edges, in km.
        ``envelope_width``: shorter of the two measured box edges, in km.

    The two edges are measured from the (min_lon, min_lat) corner: one to
    (max_lon, min_lat) and one to (min_lon, max_lat). If either geodesic
    distance cannot be computed (``compute_geodesic_distance`` returns None),
    both ``envelope_length`` and ``envelope_width`` are None.
    """
    longitudes, latitudes = compute_polygon_envelope(polygon)
    min_lat = np.min(latitudes)
    min_lon = np.min(longitudes)
    max_lat = np.max(latitudes)
    max_lon = np.max(longitudes)
    envelope_data = {"min_lat": min_lat,
                     "min_lon": min_lon,
                     "max_lat": max_lat,
                     "max_lon": max_lon}

    envelope_data_str = "({},{},{},{})".format(min_lat,
                                               min_lon,
                                               max_lat,
                                               max_lon)

    # Box corners as [longitude, latitude]; reverse_coordinates=True makes
    # compute_geodesic_distance swap them to geopy's (latitude, longitude) order.
    min_lon_min_lat_corner = [min_lon, min_lat]
    max_lon_min_lat_corner = [max_lon, min_lat]
    min_lon_max_lat_corner = [min_lon, max_lat]
    envelope_edges_lengths = [
        compute_geodesic_distance(min_lon_min_lat_corner, max_lon_min_lat_corner, True),
        compute_geodesic_distance(min_lon_min_lat_corner, min_lon_max_lat_corner, True),
    ]

    if any(edge_length is None for edge_length in envelope_edges_lengths):
        # max()/min() cannot compare None with floats; propagate "unknown".
        envelope_length = None
        envelope_width = None
    else:
        envelope_length = max(envelope_edges_lengths)
        envelope_width = min(envelope_edges_lengths)

    envelope_metadata = {
        "envelope_data": envelope_data,
        "envelope_data_str": envelope_data_str,
        "envelope_length": envelope_length,
        "envelope_width": envelope_width,
    }
    return envelope_metadata

def compute_geo_points_collection_center(geo_points_collection):
    """Return the mean ``[longitude, latitude]`` of a collection of geo points.

    geo points formats :
    - (longitude, latitude)
    - [longitude, latitude]

    ``geo_points_collection`` may be any iterable, including a generator.
    Only the first two items of each point are read.

    This is the plain arithmetic mean of the coordinates, not a geodesic
    centroid, so it is unreliable for points straddling the ±180° meridian.
    An empty collection yields NaN values (numpy emits a RuntimeWarning).
    """
    longitudes = []
    latitudes = []
    # Single pass, with no len() call, so one-shot iterables work.
    for geo_point in geo_points_collection:
        longitudes.append(geo_point[0])
        latitudes.append(geo_point[1])

    return [np.mean(longitudes), np.mean(latitudes)]


def compute_geodesic_distance(geo_point_1, geo_point_2, reverse_coordinates):
    """Return the geodesic distance in km between two points, or None on failure.

    Parameters
    ----------
    geo_point_1, geo_point_2 : tuple or list
        ``(latitude, longitude)`` pairs, the order ``geopy.distance.distance``
        expects. If ``reverse_coordinates`` is true, they are instead
        ``(longitude, latitude)`` pairs and are swapped before measuring.
    reverse_coordinates : bool
        Whether to swap the first two coordinates of each point first.

    Returns
    -------
    float or None
        Distance in km. Returns None when swapping or measuring raises, e.g. for
        an out-of-range latitude or a malformed point.
    """
    try:
        if reverse_coordinates:
            geo_point_1 = reverse_geo_point_coordinates(geo_point_1)
            geo_point_2 = reverse_geo_point_coordinates(geo_point_2)
        return geopy.distance.distance(geo_point_1, geo_point_2).km
    except Exception:
        # Deliberate best-effort: callers treat None as "distance unavailable".
        # Catching Exception, unlike a bare ``except:``, still lets
        # KeyboardInterrupt and SystemExit propagate.
        return None

def reverse_geo_point_coordinates(geo_point):
    """Return a new two-item list holding ``geo_point``'s first two items swapped.

    geo_point : 'tuple' or list, or any indexable with at least two items.
    Items after the second are dropped.
    """
    second_item = geo_point[1]
    first_item = geo_point[0]
    return [second_item, first_item]

def reverse_polygon_coordinates(polygon):
    """Swap the two coordinates of every vertex of ``polygon``.

    polygon must have format :
    [[param_1_1, param_2_1], [param_1_2, param_2_2], ...., [param_1_n, param_2_n]]

    Returns a new list of ``[param_2_i, param_1_i]`` lists; the input is not modified.
    """
    return list(map(reverse_geo_point_coordinates, polygon))

def from_polygon_coordinates_to_overpass_geometry(polygon):
    """Format a polygon as an Overpass QL ``poly`` filter string.

    polygon must have format :
    [[longitude_1, latitude_1], [longitude_2, latitude_2], ...., [longitude_n, latitude_n]]

    Each vertex is written latitude first, all separated by single spaces, e.g.
    ``[[2.3, 48.8], [2.4, 48.9]]`` -> ``'(poly:"48.8 2.3 48.9 2.4")'``.
    """
    # Each number is converted with str() individually. Stringifying the whole
    # nested list embeds element reprs instead, which under NumPy >= 2 are
    # "np.float64(...)" and would corrupt the query.
    coordinate_tokens = []
    for vertex in polygon:
        coordinate_tokens.append(str(vertex[1]))  # latitude
        coordinate_tokens.append(str(vertex[0]))  # longitude
    return '(poly:"{}")'.format(" ".join(coordinate_tokens))

def from_polygon_coordinates_to_geojson(polygon):
    """Wrap polygon vertices in a fresh copy of ``GEOJSON_POLYGON_STARTER``.

    Each vertex is copied into a list and appended to
    ``geojson["geometry"]["coordinates"][0]``. The shared template is
    deep-copied first, so it is never mutated.
    """
    geojson = copy.deepcopy(GEOJSON_POLYGON_STARTER)
    first_ring = geojson["geometry"]["coordinates"][0]
    first_ring.extend(list(vertex) for vertex in polygon)
    return geojson

def from_coordinates_to_polygon(longitudes, latitudes):
    """Pair up longitudes and latitudes into ``[[lon, lat], ...]``.

    Stops at the shorter of the two sequences.
    """
    return list(map(list, zip(longitudes, latitudes)))

def from_geojson_to_polygon_coordinates(geojson):
    """Return the first coordinate ring of a GeoJSON Polygon feature (not a copy)."""
    rings = geojson["geometry"]["coordinates"]
    return rings[0]

def compute_circular_polygon(circle_center_lon, circle_center_lat, circle_radius_m):
    """Approximate a circle on the Earth's surface as a ``[longitude, latitude]`` polygon.

    The circle is buffered in a local azimuthal-equidistant projection
    (sphere of radius 6371000 m) centred on the circle centre. There, planar
    distance from the origin equals great-circle distance from the centre.
    The resulting ring is then projected back to WGS84 longitude/latitude.

    Parameters
    ----------
    circle_center_lon, circle_center_lat : float or str
        Circle centre in degrees; converted with ``float()``.
    circle_radius_m : float
        Circle radius in metres.

    Returns
    -------
    list of [longitude, latitude]
        Exterior-ring vertices, as produced by ``read_as_regular_polygon``.
    """
    center_lon = float(circle_center_lon)
    center_lat = float(circle_center_lat)
    local_azimuthal_projection = "+proj=aeqd +R=6371000 +units=m +lat_0={} +lon_0={}".format(
        center_lat, center_lon
    )
    # pyproj.transform(Proj, Proj, ...) is deprecated; a Transformer is the
    # supported API. always_xy=True fixes the (x, y) order to (lon, lat).
    aeqd_to_wgs84 = pyproj.Transformer.from_crs(
        local_azimuthal_projection,
        "+proj=longlat +datum=WGS84 +no_defs",
        always_xy=True,
    )

    # The projection is centred on the circle centre, which is therefore the origin.
    projected_disc = Point(0.0, 0.0).buffer(circle_radius_m)

    shapely_circular_polygon = transform(aeqd_to_wgs84.transform, projected_disc)
    return read_as_regular_polygon(shapely_circular_polygon)

class PolygonsIndexer:
    """Find the polygons containing a geo point, using a KD-tree of polygon centres.

    Candidates are the centres nearest to the query point according to the
    KD-tree. With scikit-learn's default metric this is Euclidean distance on
    the raw coordinates, not geodesic distance. Each candidate is then
    confirmed with an exact point-in-polygon test (``geo_point_is_in_geojson``).
    """

    def __init__(self, polygons_centers):
        """Build the KD-tree over ``polygons_centers``.

        Row ``i`` is the centre of the polygon whose GeoJSON sits at index ``i``
        of the ``polygons_geojsons`` later passed to
        ``search_geo_point_belonging_polygons``.
        NOTE(review): presumably ``[longitude, latitude]`` rows, matching
        ``compute_geo_points_collection_center``; confirm against callers.
        """
        polygons_centers = np.array(polygons_centers)
        # Kept so queries can cap k: KDTree.query rejects k larger than the
        # number of indexed points.
        self.n_polygons = len(polygons_centers)
        self.kd_tree = KDTree(polygons_centers, leaf_size=100)

    def search_geo_point_belonging_polygons(self, geo_point, polygons_geojsons, n_neighbors, successive_exclusions_stopping):
        """Return the ids of indexed polygons that contain ``geo_point``.

        Parameters
        ----------
        geo_point : sequence of two numbers
            Query point, in the same coordinate order as the indexed centres.
        polygons_geojsons : indexable of GeoJSON Polygon features
            Indexed like the centres given to ``__init__``.
        n_neighbors : int
            Number of nearest centres to test. Values above the number of
            indexed polygons are capped to it.
        successive_exclusions_stopping : int
            Candidates are scanned nearest first. Scanning stops once the
            count of consecutive candidates not containing the point exceeds
            this value.

        Returns
        -------
        (list, int)
            Matching polygon ids, as returned by the KD-tree, and their count.
        """
        n_candidates = min(n_neighbors, self.n_polygons)
        _, neighbor_indices = self.kd_tree.query(np.array([geo_point]), k=n_candidates)
        geo_point_belonging_polygons = []
        n_exclusions = 0

        for polygon_id in neighbor_indices[0]:
            polygon_geojson = polygons_geojsons[polygon_id]
            if geo_point_is_in_geojson(geo_point, polygon_geojson):
                geo_point_belonging_polygons.append(polygon_id)
                n_exclusions = 0
            else:
                n_exclusions += 1
            # Checked after both branches (as before) so a negative threshold
            # still stops right after the first candidate.
            if n_exclusions > successive_exclusions_stopping:
                break

        return geo_point_belonging_polygons, len(geo_point_belonging_polygons)


def convert_from_geojson_to_list_polygon(geojson):
    """Return the first coordinate ring of a GeoJSON Polygon feature, by reference."""
    geometry = geojson["geometry"]
    polygon_rings = geometry["coordinates"]
    return polygon_rings[0]


def convert_polygon_from_list_to_shapely(list_polygon, reverse_coordinates):
    """Build a shapely ``Polygon`` from a list of coordinate pairs.

    When ``reverse_coordinates`` is true, the two coordinates of every vertex
    are swapped first (via ``reverse_polygon_coordinates``).
    """
    vertices = reverse_polygon_coordinates(list_polygon) if reverse_coordinates else list_polygon
    return Polygon(vertices)


def flip_xy_coordinates(x, y, z=None):
    """Swap x and y (e.g. lon/lat <-> lat/lon), keeping z if present.

    Shaped for ``shapely.ops.transform``, which passes the x, y (and, for 3D
    geometries, z) coordinates of a geometry and expects them back.
    """
    if z is None:
        return y, x
    return y, x, z


def convert_polygon_from_shapely_to_wkt_string(shapely_polygon, reverse_coordinates):
    """Return the WKT text of ``shapely_polygon``.

    When ``reverse_coordinates`` is true, x and y are swapped on every vertex
    before serialising. The original version referenced an undefined ``flip``
    and raised NameError on that path.
    """
    if reverse_coordinates:
        shapely_polygon = transform(flip_xy_coordinates, shapely_polygon)
    return shapely_polygon.wkt