from sklearn.neighbors import KDTree
import numpy as np
import pandas as pd
from shapely.geometry import Point, Polygon

from ..feature_engineering.information_extraction import (get_geo_points_collection_center,
                                                          merge_list_multipolygon_geo_points,
                                                          sort_neighbor_geo_points_indexes_by_geodesic_distance)


def geo_point_is_in_polygon(geo_point, polygon):
    """Tell whether a point lies within a polygon, using shapely's ``within`` predicate.

    :param geo_point: Point coordinates, passed unchanged to ``shapely.geometry.Point``.
    :param polygon: Polygon vertices, passed unchanged to ``shapely.geometry.Polygon``.
    :return: Boolean result of ``Point(geo_point).within(Polygon(polygon))``.
    """
    return Point(geo_point).within(Polygon(polygon))


def polygon_contains_geo_point(geometry_polygon, geo_point):
    """Tell whether a polygon contains a point, using shapely's ``contains`` predicate.

    :param geometry_polygon: Polygon vertices, passed unchanged to ``shapely.geometry.Polygon``.
    :param geo_point: Point coordinates, passed unchanged to ``shapely.geometry.Point``.
    :return: Boolean result of ``Polygon(geometry_polygon).contains(Point(geo_point))``.
    """
    return Polygon(geometry_polygon).contains(Point(geo_point))


class PolygonsIndexer:
    """Index polygons by their centers to find which polygons contain given geo points.

    A KD-tree is built on the center of every polygon. For a query point, the
    polygons whose centers are nearest are taken as candidates, and each
    candidate is then tested for actual inclusion of the point.
    """

    def __init__(self, list_of_polygons, polygons_types):
        """
        :param list_of_polygons: Polygons to index. A "polygon" entry is stored as-is; a "multipolygon"
            entry is stored as the output of ``merge_list_multipolygon_geo_points``, whose items are later
            tested one by one as sub-polygons.
        :param polygons_types: Type of each entry of list_of_polygons ("polygon" or "multipolygon"),
            in the same order.
        :raises Exception: If a type other than "polygon" or "multipolygon" is encountered.
        """
        self.list_polygons = []
        self.list_polygons_types = []
        for polygon, polygon_type in zip(list_of_polygons, polygons_types):
            if polygon_type == "polygon":
                self.list_polygons.append(polygon)
            elif polygon_type == "multipolygon":
                self.list_polygons.append(merge_list_multipolygon_geo_points(polygon))
            else:
                log_message = "Polygons of type '{}' are not handled by this class. " \
                              "Accepted types are 'polygon' and 'multipolygon'".format(polygon_type)
                raise Exception(log_message)
            self.list_polygons_types.append(polygon_type)
        self.n_polygons = len(list_of_polygons)
        # Centers come from the raw input entries, not from the merged multipolygons stored above.
        self.polygons_centers = np.array([get_geo_points_collection_center(polygon)
                                          for polygon in list_of_polygons])
        print("Indexing data...")
        self.kd_tree = KDTree(self.polygons_centers, leaf_size=100)
        print("Data indexed !")

    def _geo_point_in_candidate(self, geo_point, polygon_index):
        """Return True if geo_point lies in the indexed polygon (or any sub-polygon) at polygon_index."""
        candidate_polygon = self.list_polygons[polygon_index]
        if self.list_polygons_types[polygon_index] == "multipolygon":
            # Always yields an explicit boolean: the previous code left the flag unset (or stale from
            # the previous candidate) when no sub-polygon matched.
            return any(geo_point_is_in_polygon(geo_point, sub_polygon) for sub_polygon in candidate_polygon)
        return geo_point_is_in_polygon(geo_point, candidate_polygon)

    def search_single_geo_point_belonging_polygons(self, geo_point, n_neighbors, n_successive_exclusions_stopping,
                                                   bool_search_multiple_polygons):
        """Find the indexed polygons that contain one geo point.

        :param geo_point: Query point, used both for the KD-tree query on polygon centers and for the
            inclusion tests.
        :param n_neighbors: Number of nearest polygon centers examined (capped at the number of polygons).
        :param n_successive_exclusions_stopping: The scan over candidates stops once more than this many
            consecutive candidates do not contain the point.
        :param bool_search_multiple_polygons: If False and several polygons match, only one is kept: the
            first one returned by ``sort_neighbor_geo_points_indexes_by_geodesic_distance`` applied to
            the matching polygons' centers.
        :return: Tuple (matching polygon indexes, or ``[None]`` when nothing matched;
            number of matches found before any reduction to a single polygon).
        """
        n_neighbors = min(self.n_polygons, n_neighbors)
        _, neighbor_polygons = self.kd_tree.query(np.array([geo_point]), k=n_neighbors)
        geo_point_belonging_polygons = []
        n_exclusions = 0

        for polygon_index in neighbor_polygons[0]:
            if self._geo_point_in_candidate(geo_point, polygon_index):
                geo_point_belonging_polygons.append(polygon_index)
                n_exclusions = 0
            else:
                n_exclusions += 1
            if n_exclusions > n_successive_exclusions_stopping:
                break

        n_polygons_found = len(geo_point_belonging_polygons)
        if n_polygons_found == 0:
            return [None], n_polygons_found

        if not bool_search_multiple_polygons and n_polygons_found > 1:
            neighbor_polygons_geo_points = [self.polygons_centers[index]
                                            for index in geo_point_belonging_polygons]
            sorted_neighbors_indexes = sort_neighbor_geo_points_indexes_by_geodesic_distance(
                geo_point, neighbor_polygons_geo_points, False)
            geo_point_belonging_polygons = [geo_point_belonging_polygons[sorted_neighbors_indexes[0]]]

        return geo_point_belonging_polygons, n_polygons_found

    def search_geo_points_belonging_polygons(self, geo_points, n_neighbors, n_successive_exclusions_stopping,
                                             bool_search_multiple_polygons):
        """Find the containing polygons of every geo point.

        Parameters other than geo_points are forwarded to ``search_single_geo_point_belonging_polygons``.

        :param geo_points: Sized sequence of query points.
        :return: DataFrame (object dtype) with columns "geo_point_index" and "included_in_polygon_index";
            one row per (point, matching polygon) pair, and a single row with ``None`` as polygon index
            for points that matched no polygon.
        """
        rows = []
        geo_points_successfully_assigned = 0
        n_geo_points = len(geo_points)

        for geo_point_index, geo_point in enumerate(geo_points):
            geo_point_belonging_polygons, n_polygons_found = self.search_single_geo_point_belonging_polygons(
                geo_point, n_neighbors, n_successive_exclusions_stopping, bool_search_multiple_polygons)
            if n_polygons_found > 0:
                geo_points_successfully_assigned += 1
            # One row per found polygon; when nothing was found the list is [None], giving one
            # "unassigned" row (the previous code wrote an undefined/stale polygon index here and
            # dropped all but the last polygon when several matched).
            rows.extend({"geo_point_index": geo_point_index, "included_in_polygon_index": polygon_index}
                        for polygon_index in geo_point_belonging_polygons)

            log_message = "geo_point inclusion in polygons checked ({}/{}).".format(geo_point_index + 1,
                                                                                    n_geo_points)
            log_message += " Success assignment ratio : {}".format(
                geo_points_successfully_assigned / (geo_point_index + 1))
            print(log_message)

        # Built once instead of one pd.concat per row (quadratic); object dtype keeps None values as None.
        return pd.DataFrame(rows, columns=["geo_point_index", "included_in_polygon_index"], dtype=object)


class GeoPointsIndexer:
    """KD-tree index over a list of geo points, used for nearest-neighbor lookups."""

    def __init__(self, list_of_geo_points):
        """
        :param list_of_geo_points: List of geo points in [latitude, longitude] or (latitude, longitude) format
        """
        self.list_of_geo_points = list_of_geo_points
        self.n_geo_points = len(list_of_geo_points)
        geo_points_data = np.array(list(list_of_geo_points))
        print("Indexing data...")
        self.kd_tree = KDTree(geo_points_data, leaf_size=100)
        print("Data indexed !")

    def search_single_geo_point_neighbors(self, geo_point, n_neighbors_to_search, n_neighbors_to_retrieve,
                                          bool_sort_results_by_geodesic_distance):
        """Find the indexes of the indexed geo points nearest to one geo point.

        :param geo_point: Geo point in [latitude, longitude] or (latitude, longitude) format
        :param n_neighbors_to_search: Number of candidate neighbors queried from the KD-tree
            (capped at the number of indexed points).
        :param n_neighbors_to_retrieve: Number of neighbor indexes returned
            (capped at the effective n_neighbors_to_search).
        :param bool_sort_results_by_geodesic_distance: If True, candidates are re-ordered with
            ``sort_neighbor_geo_points_indexes_by_geodesic_distance`` before truncation; otherwise the
            KD-tree query order is kept.
        :return: List of indexes into ``list_of_geo_points``.
        """
        n_neighbors_to_search = min(self.n_geo_points, n_neighbors_to_search)
        n_neighbors_to_retrieve = min(n_neighbors_to_retrieve, n_neighbors_to_search)
        _, neighbor_indexes_array = self.kd_tree.query(np.array([geo_point]), k=n_neighbors_to_search)
        neighbor_geo_points_indexes = list(neighbor_indexes_array[0])

        if bool_sort_results_by_geodesic_distance:
            candidate_geo_points = [self.list_of_geo_points[index] for index in neighbor_geo_points_indexes]
            # The helper returns positions within candidate_geo_points; map them back to global indexes.
            sorted_positions = sort_neighbor_geo_points_indexes_by_geodesic_distance(
                geo_point, candidate_geo_points, False)  # reverse_coordinates=False
            neighbor_geo_points_indexes = [neighbor_geo_points_indexes[position] for position in sorted_positions]
        return neighbor_geo_points_indexes[:n_neighbors_to_retrieve]

    def search_geo_points_neighbors(self, geo_points, n_neighbors_to_search, n_neighbors_to_retrieve,
                                    bool_sort_results_by_geodesic_distance):
        """Run ``search_single_geo_point_neighbors`` for every geo point, printing progress.

        :return: List (same order as geo_points) of neighbor index lists.
        """
        geo_points_neighbors = []
        n_geo_points = len(geo_points)
        for geo_point_index, geo_point in enumerate(geo_points):
            geo_points_neighbors.append(self.search_single_geo_point_neighbors(
                geo_point, n_neighbors_to_search, n_neighbors_to_retrieve, bool_sort_results_by_geodesic_distance))
            print("geo_point ({}/{}) neighbors found.".format(geo_point_index + 1, n_geo_points))
        return geo_points_neighbors


def search_geo_points_belonging_polygons(geo_points, polygons):
    """Brute-force check of every (geo point, polygon) pair, printing progress after each check.

    :param geo_points: Sized sequence of geo points, each passed to ``geo_point_is_in_polygon``.
    :param polygons: Sized sequence of simple polygons, each passed to ``geo_point_is_in_polygon``.
    :return: DataFrame (object dtype) with columns "geo_point_index" and "included_in_polygon_index",
        one row per pair where the point lies within the polygon.
    """
    rows = []
    n_combinations = len(geo_points) * len(polygons)
    loop_index = 0

    for geo_point_index, geo_point in enumerate(geo_points):
        for polygon_index, polygon in enumerate(polygons):
            if geo_point_is_in_polygon(geo_point, polygon):
                # Key must match the declared column (the old code wrote "included_geo_point_index").
                rows.append({"geo_point_index": geo_point_index, "included_in_polygon_index": polygon_index})
            loop_index += 1
            print(f"geo_point inclusion in polygons checked ({loop_index}/{n_combinations})")

    # DataFrame.append no longer exists in pandas >= 2.0; build the frame once from collected rows.
    return pd.DataFrame(rows, columns=["geo_point_index", "included_in_polygon_index"], dtype=object)