import pandas as pd
from typing import Tuple
import numpy as np
from shapely import geometry
from .collector import Collector


class Writer(Collector):
    """Convert point-of-interest API responses into aligned input/output rows.

    Two parallel lists are produced: ``input_list`` holds the original
    dataframe rows (repeated once per emitted enrichment) and ``output_list``
    holds the matching enrichment rows, so that ``input_list[i]`` and
    ``output_list[i]`` always describe the same point of interest.

    Relies on ``self.keys``, ``self.request_by_batch`` and
    ``self._extract_tags`` which are not defined in this class (presumably
    provided by ``Collector`` -- verify against that class).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._input_list = []
        self._output_list = []

    def write_pois(self, datas: dict, df: pd.DataFrame) -> Tuple[list, list]:
        """Build the input/output rows using the batch or the per-row strategy.

        Args:
            datas: API response(s) for the geometries in ``df``; the expected
                layout depends on ``self.request_by_batch`` (see the two
                private writers).
            df: dataframe holding the geometries that were queried.

        Returns:
            The ``(input_list, output_list)`` pair built by the selected writer.
        """
        if self.request_by_batch:
            return self._batch_write_pois(datas, df)
        return self._row_write_pois(datas, df)

    def _nan_enrichment_row(self, failure_response) -> list:
        """Return a fresh placeholder enrichment row for a geometry without POI.

        The row holds ``len(self.keys) + 2`` NaN values followed by the failure
        response, which lines up with the layout of a node row provided
        ``_extract_tags`` yields one value per key (assumption -- confirm).
        A new list is created on every call so rows never alias each other.
        """
        return [np.nan] * (len(self.keys) + 2) + [failure_response]

    def _node_enrichment_row(self, element: dict, failure_response) -> list:
        """Return the enrichment row describing a single ``node`` element.

        Layout: extracted tag values, the raw ``tags`` mapping, a shapely
        ``Point(lon, lat)`` and finally the failure response.
        """
        return self._extract_tags(element["tags"]) + [
            element["tags"],
            geometry.Point(element["lon"], element["lat"]),
            failure_response,
        ]

    def _batch_write_pois(self, datas: dict, df: pd.DataFrame) -> Tuple[list, list]:
        """For each point of interest of a geometry, store:
        -the initial row of the dataset which contains the geometry in 'input_list'
        -the enrichment information from the geometry in 'output_list', with:
            -values extracted from the tags by ``self._extract_tags``
             (presumably one per entry of ``self.keys``)
            -tags associated to the point of interest
            -geopoint of the point of interest
            -failure status code and message, if any

        Elements are consumed in order: ``node`` elements are attributed to the
        dataframe row at the current batch position, and every ``count``
        element closes the current geometry and moves to the next row. A
        ``count`` whose total is ``"0"`` yields one NaN placeholder row so the
        geometry still appears in the output.

        Args:
            datas (dict): API response for querying points of interest for the
                geometries in 'df'; read keys are "elements" and
                "failure_response"
            df (pandas.DataFrame): dataframe which contains the geometries we queried

        Returns:
            input_list: Initial rows with as many duplicates as there are points of interest in the geometry
            output_list: rows composed of the enrichments, as defined above
        """
        if not datas["elements"]:  # no element at all, e.g. a client/server error occurred
            # One independent placeholder per dataframe row. Multiplying the
            # outer list would make every entry reference the same inner list,
            # so mutating one output row would silently change all of them.
            self._output_list = [self._nan_enrichment_row(datas["failure_response"]) for _ in range(len(df))]
            self._input_list = df.values.tolist()
            return self._input_list, self._output_list

        # Only one batch is meant to be processed at this stage: start clean.
        self._input_list = []
        self._output_list = []
        batch_idx = 0  # position in df of the geometry currently being read
        for element in datas["elements"]:
            element_type = element["type"]
            if element_type == "node":
                self._output_list.append(self._node_enrichment_row(element, datas["failure_response"]))
                self._input_list.append(df.iloc[batch_idx].tolist())
            elif element_type == "count":
                if element["tags"]["total"] == "0":
                    # Geometry without any point of interest: keep it with NaNs.
                    self._output_list.append(self._nan_enrichment_row(datas["failure_response"]))
                    self._input_list.append(df.iloc[batch_idx].tolist())
                # A count element marks the end of the current geometry.
                batch_idx += 1
        return self._input_list, self._output_list

    def _row_write_pois(self, datas, df: pd.DataFrame) -> Tuple[list, list]:
        """Build the input/output rows when each geometry was queried on its own.

        Same row layout as the batch writer, but the response of each geometry
        is looked up individually with the dataframe index label.

        Args:
            datas: container indexable by the index labels of 'df'; each entry
                is a response with the keys "elements" and "failure_response"
            df (pandas.DataFrame): dataframe which contains the geometries we queried

        Returns:
            input_list: Initial rows with as many duplicates as there are points of interest in the geometry
            output_list: enrichment rows aligned with 'input_list'
        """
        # Only one batch is meant to be processed at this stage: start clean.
        self._input_list = []
        self._output_list = []
        for index, row in df.iterrows():
            response = datas[index]
            if not response["elements"]:
                # Nothing came back for this geometry: keep it with NaNs.
                self._output_list.append(self._nan_enrichment_row(response["failure_response"]))
                self._input_list.append(row.tolist())
                continue
            for element in response["elements"]:
                element_type = element["type"]
                if element_type == "node":
                    self._output_list.append(self._node_enrichment_row(element, response["failure_response"]))
                    # row.tolist() is called per append so input rows stay independent lists.
                    self._input_list.append(row.tolist())
                elif element_type == "count" and element["tags"]["total"] == "0":
                    self._output_list.append(self._nan_enrichment_row(response["failure_response"]))
                    self._input_list.append(row.tolist())
        return self._input_list, self._output_list
