import dataiku
from commons.python.fetch.config_bs import ConfigBs, EnvMode
import pandas as pd
from .setup_logger import logger
from io import StringIO
from operator import itemgetter
from typing import List, Dict, Any, Optional, Union


class SdohConfig:
    """Static names used by the SDOH backend: project key, folders, files,
    datasets and the column names looked up in the dataset schemas."""

    # --- Project ---------------------------------------------------------
    LOCAL_PROJECT_KEY = "SOL_SDOH"

    # --- Managed folders and the files read from them --------------------
    TRACTS_GEOMETRY_FOLDER = "tracts_data"
    COUNTIES_GEOMETRY_FILE = "counties.json"
    TRACTS_GEOMETRY_FILE = "tracts_data_complete.json"
    STATE_TO_COUNTY_FOLDER = "state_to_county"
    STATE_CODES_LOOKUP_CSV = "state_codes_lookup.csv"

    # --- Datasets --------------------------------------------------------
    DISEASE_LIST_DATASET = "svi_vulnerability_cdc_prepared_distinct_stacked"
    TRACTS_DATASET = "new_measure_final_dataset_tract"
    COUNTY_TRACT_DATASET = "new_measure_tract_aggregate_county"
    COUNTY_DATASET = "new_measure_final_dataset_county"

    # --- Themes ----------------------------------------------------------
    # TODO: themes are hard coded here; fetch them from a data source instead.
    THEMES = [
        "Housing Type and Transportation",
        "Racial and Ethnic Minority Status",
        "Household Characteristics",
        "Socioeconomic Status",
    ]
    SOCIAL_VULNERABILITY_INDEX = "Social Vulnerability Index"

    # --- Population columns ----------------------------------------------
    POPULATION_COL_TRACT = "Population"
    POPULATION_DENSITY_COL_TRACT = "Population_Density_tract"
    POPULATION_COL_COUNTY = "Population_county"
    POPULATION_DENSITY_COL_COUNTY = "Population_Density_county"
    POPULATION_COLS_TRACT = [POPULATION_COL_TRACT, POPULATION_DENSITY_COL_TRACT]
    POPULATION_COLS_COUNTY = [POPULATION_COL_COUNTY, POPULATION_DENSITY_COL_COUNTY]
    # The tract-aggregated-to-county dataset uses the county column names;
    # kept as its own list so it can be changed independently.
    POPULATION_COLS_TRACT_COUNTY = list(POPULATION_COLS_COUNTY)

    # --- Other columns ---------------------------------------------------
    # Identical today for the three modes, but kept as three separate lists so
    # each mode can be reconfigured on its own.
    OTHER_COLS_TRACT = ["FIPS", "cluster_labels"]
    OTHER_COLS_COUNTY = list(OTHER_COLS_TRACT)
    OTHER_COLS_TRACT_COUNTY = list(OTHER_COLS_TRACT)


class SdohPreprocessor:
    @classmethod
    def format_string(cls, s: str):
        """Return ``s`` with leading and trailing whitespace removed."""
        return s.strip()

    @classmethod
    def get_disease_columns_from_schema(
        cls, diseases: List[str], schema: List[Dict[str, str]]
    ):
        """Map each disease name to its percent / percentile schema columns.

        Column names in the flow are not consistent, so matching is done by
        substring rather than by exact name:

        * percent column: any column containing ``"Percent <disease>"``;
        * percentile column: any column ending with ``"Percentile"`` that
          contains the disease name.

        When several columns match, the last one in schema order is kept.

        Args:
            diseases: Disease names to look up.
            schema: Schema entries; each entry's ``"name"`` is a column name.

        Returns:
            ``{disease: {"percent": column, "percentile": column}}`` containing
            only the diseases for which BOTH columns were found. Diseases that
            are not found are simply left out.
        """
        schema_columns = [entry.get("name") for entry in schema]
        mapping = {}
        for disease in diseases:
            percent_marker = f"Percent {disease}"
            percent_col = None
            percentile_col = None
            for candidate in schema_columns:
                if percent_marker in candidate:
                    percent_col = candidate
                if candidate.endswith("Percentile") and disease in candidate:
                    percentile_col = candidate
            # Keep the disease only when both columns are available.
            if percent_col and percentile_col:
                mapping[disease] = {
                    "percent": percent_col,
                    "percentile": percentile_col,
                }
        return mapping

    @classmethod
    def get_percentile_theme_cols_from_schema(
        cls, themes: List[str], schema: List[Dict[str, str]]
    ):
        """Map each theme name to its percentile column in ``schema``.

        Column names in the flow are not consistent, so a column matches a
        theme when it ends with ``"Percentile"`` and contains the theme name.
        When several columns match, the last one in schema order is kept.

        Args:
            themes: Theme names to look up.
            schema: Schema entries; each entry's ``"name"`` is a column name.

        Returns:
            ``{theme: column}`` for the themes that were found; themes without
            a matching column are left out.
        """
        schema_columns = [entry.get("name") for entry in schema]
        # Later matches overwrite earlier ones for the same theme.
        return {
            theme: candidate
            for theme in themes
            for candidate in schema_columns
            if candidate.endswith("Percentile") and theme in candidate
        }

    @classmethod
    def get_social_index_col(cls, schema: List[Dict[str, str]]):
        """Return the first schema column whose name contains
        ``SdohConfig.SOCIAL_VULNERABILITY_INDEX``, or ``None`` if there is none."""
        matching_names = (
            entry["name"]
            for entry in schema
            if SdohConfig.SOCIAL_VULNERABILITY_INDEX in entry["name"]
        )
        return next(matching_names, None)

    @classmethod
    def get_population_columns(cls, schema: List[Dict[str, str]], mode: str):
        """Return the population column names configured for ``mode``.

        Args:
            schema: Schema entries of the dataset backing ``mode``.
            mode: ``"county"`` or ``"tract"``; any other value selects the
                tract-aggregated-to-county column list.

        Returns:
            The configured names, in configuration order. A name is emitted
            once per schema entry carrying it.

        Raises:
            ValueError: If a configured column is absent from ``schema``.
        """
        if mode == "county":
            wanted = SdohConfig.POPULATION_COLS_COUNTY
        elif mode == "tract":
            wanted = SdohConfig.POPULATION_COLS_TRACT
        else:
            wanted = SdohConfig.POPULATION_COLS_TRACT_COUNTY

        selected = []
        for name in wanted:
            matches = [name for entry in schema if name == entry["name"]]
            if not matches:
                raise ValueError(f"{name} column not found in {mode} Dataset")
            selected.extend(matches)
        return selected

    @classmethod
    def get_other_columns(cls, schema: List[Dict[str, str]], mode: str):
        """Return the non-measure column names configured for ``mode``.

        Args:
            schema: Schema entries of the dataset backing ``mode``.
            mode: ``"county"`` or ``"tract"``; any other value selects the
                tract-aggregated-to-county column list.

        Returns:
            The configured names, in configuration order. A name is emitted
            once per schema entry carrying it.

        Raises:
            ValueError: If a configured column is absent from ``schema``.
        """
        if mode == "county":
            wanted = SdohConfig.OTHER_COLS_COUNTY
        elif mode == "tract":
            wanted = SdohConfig.OTHER_COLS_TRACT
        else:
            wanted = SdohConfig.OTHER_COLS_TRACT_COUNTY

        selected = []
        for name in wanted:
            matches = [name for entry in schema if name == entry["name"]]
            if not matches:
                raise ValueError(f"{name} column not found in {mode} Dataset")
            selected.extend(matches)
        return selected


class DataikuClient:
    """Holds the Dataiku API client together with the resolved project.

    The project key is ``SdohConfig.LOCAL_PROJECT_KEY`` when ``ConfigBs.mode()``
    equals ``EnvMode.LOCAL.value``; otherwise it is read from the
    ``"projectKey"`` entry of ``dataiku.get_custom_variables()``.
    """

    def __init__(self):
        self.__client = dataiku.api_client()
        self.__mode = ConfigBs.mode()
        if self.__mode == EnvMode.LOCAL.value:
            resolved_key = SdohConfig.LOCAL_PROJECT_KEY
        else:
            resolved_key = dataiku.get_custom_variables()["projectKey"]
        self.__project_key = resolved_key
        self.__project = self.__client.get_project(self.__project_key)

    @property
    def project(self):
        """Project handle obtained from the API client for ``project_key``."""
        return self.__project

    @property
    def project_key(self):
        """Key of the project this client is bound to."""
        return self.__project_key

    @property
    def client(self):
        """Underlying Dataiku API client."""
        return self.__client


class SDOH(object):
    def __init__(self, dataiku_client: DataikuClient):
        self.project_key = dataiku_client.project_key

        self.counties_geometry = dataiku.Folder(
            SdohConfig.TRACTS_GEOMETRY_FOLDER, project_key=self.project_key
        ).read_json(SdohConfig.COUNTIES_GEOMETRY_FILE)
        logger.info("Counties geometry loaded")

        self.tracts_geometry = dataiku.Folder(
            SdohConfig.TRACTS_GEOMETRY_FOLDER, project_key=self.project_key
        ).read_json(SdohConfig.TRACTS_GEOMETRY_FILE)
        logger.info("Tracts geometry loaded")

        self.tract_dataset = dataiku.Dataset(
            SdohConfig.TRACTS_DATASET, project_key=self.project_key
        )
        self.county_tract_dataset = dataiku.Dataset(
            SdohConfig.COUNTY_TRACT_DATASET, project_key=self.project_key
        )
        self.county_dataset = dataiku.Dataset(
            SdohConfig.COUNTY_DATASET, project_key=self.project_key
        )

        self.tract_dataset_schema = self.tract_dataset.read_schema()
        self.county_tract_dataset_schema = self.county_tract_dataset.read_schema()
        self.county_dataset_schema = self.county_dataset.read_schema()
        logger.info("Datasets handles created")

        self.disease_map = self.get_disease_map()
        logger.info("Disease list loaded & formated")

        self.themes_map = self.get_themes_map()
        logger.info("Themes list loaded & formated")

        self.__state_counties = self.get_state_counties()
        logger.info("State counties object generated")

        self.cache = {}

    @property
    def state_counties(self):
        return self.__state_counties

    @property
    def diseases(self):
        return list(self.disease_map.values())

    @property
    def themes(self):
        return list(self.themes_map.values())

    @property
    def themes_client(self):
        return [theme["name"] for theme in self.themes]

    @property
    def diseases_client(self):
        return [
            {
                "value": disease["name"],
                "label": disease["name"],
                "isTract": disease["is_tract"],
            }
            for disease in self.diseases
        ]

    def get_disease_map(self):
        ## Should filter out diseases not found in columns
        dataset_disease = dataiku.Dataset(
            SdohConfig.DISEASE_LIST_DATASET, project_key=self.project_key
        )
        df_disease = dataset_disease.get_dataframe()
        disease_list = df_disease.to_dict(orient="records")
        disease_list = [
            {
                "name": SdohPreprocessor.format_string(disease["Health Reason"]),
                "is_tract": disease["Area_Level"] == "tract",
            }
            for disease in disease_list
        ]
        tracts_mapping = SdohPreprocessor.get_disease_columns_from_schema(
            [disease["name"] for disease in disease_list if disease["is_tract"]],
            self.tract_dataset_schema,
        )
        tracts_county_mapping = SdohPreprocessor.get_disease_columns_from_schema(
            [disease["name"] for disease in disease_list if disease["is_tract"]],
            self.county_tract_dataset_schema,
        )
        county_mapping = SdohPreprocessor.get_disease_columns_from_schema(
            [disease["name"] for disease in disease_list if not disease["is_tract"]],
            self.county_dataset_schema,
        )
        result = {}
        for disease in disease_list:
            if disease["is_tract"]:
                if tracts_mapping.get(disease["name"]) and tracts_county_mapping.get(
                    disease["name"]
                ):
                    result[disease["name"]] = {
                        "name": disease["name"],
                        "is_tract": True,
                        "tract_columns": tracts_mapping.get(disease["name"]),
                        "county_columns": tracts_county_mapping.get(disease["name"]),
                    }

            else:
                if county_mapping.get(disease["name"]):
                    result[disease["name"]] = {
                        "name": disease["name"],
                        "is_tract": False,
                        "tract_columns": None,
                        "county_columns": county_mapping.get(disease["name"]),
                    }
        return result

    def get_state_counties(self):
        """
        Creates an obj = {"state_code" : { "state_code" : code, "state_name" : name, counties: [{"county_code": code, "county_name": name} ...]} ...}
        """
        state_to_county_path = dataiku.Folder(
            SdohConfig.STATE_TO_COUNTY_FOLDER, project_key=self.project_key
        )
        with state_to_county_path.get_download_stream(
            SdohConfig.STATE_CODES_LOOKUP_CSV
        ) as f:
            data = f.read()

        state_codes_lookup = pd.read_csv(StringIO(data.decode("utf-8")), index_col=0)
        state_codes_lookup["State_code"] = (
            state_codes_lookup["State_code"].astype(int).astype(str).str.zfill(2)
        )
        states_names = {
            state_codes_lookup["State_code"][i]: state_codes_lookup["State_name"][i]
            for i in range(len(state_codes_lookup["State_code"]))
        }
        obj = {}
        for item in self.counties_geometry["features"]:
            state_code = item["properties"]["STATEFP"]
            county_code = item["properties"]["COUNTYFP"]
            county_name = (
                item["properties"]["NAMELSAD"]
                if not item["properties"]["NAMELSAD"].endswith("County")
                else item["properties"]["NAMELSAD"].replace("County", "").strip()
            )
            if states_names.get(state_code):
                if state_code in obj:
                    obj[state_code]["counties"][county_code] = {
                        "county_code": county_code,
                        "county_name": county_name,
                    }
                else:
                    obj[state_code] = {
                        "state_code": state_code,
                        "state_name": states_names[state_code],
                        "counties": {},
                    }
                    obj[state_code]["counties"][county_code] = {
                        "county_code": county_code,
                        "county_name": county_name,
                    }

        return obj

    def get_themes_map(self):
        tracts_mapping = SdohPreprocessor.get_percentile_theme_cols_from_schema(
            SdohConfig.THEMES, self.tract_dataset_schema
        )
        tracts_county_mapping = SdohPreprocessor.get_percentile_theme_cols_from_schema(
            SdohConfig.THEMES, self.county_tract_dataset_schema
        )
        county_mapping = SdohPreprocessor.get_percentile_theme_cols_from_schema(
            SdohConfig.THEMES, self.county_dataset_schema
        )
        social_index_col_tract = SdohPreprocessor.get_social_index_col(
            self.tract_dataset_schema
        )
        social_index_col_tract_county = SdohPreprocessor.get_social_index_col(
            self.county_tract_dataset_schema
        )
        social_index_col_county = SdohPreprocessor.get_social_index_col(
            self.county_dataset_schema
        )
        result = {}
        if (
            social_index_col_tract
            and social_index_col_tract_county
            and social_index_col_county
        ):
            result[SdohConfig.SOCIAL_VULNERABILITY_INDEX] = {
                "name": SdohConfig.SOCIAL_VULNERABILITY_INDEX,
                "tract_column": social_index_col_tract,
                "tract_county_column": social_index_col_tract_county,
                "county_column": social_index_col_county,
            }
        for theme in SdohConfig.THEMES:
            if (
                tracts_mapping.get(theme)
                and tracts_county_mapping.get(theme)
                and county_mapping.get(theme)
            ):
                result[theme] = {
                    "name": theme,
                    "tract_column": tracts_mapping.get(theme),
                    "tract_county_column": tracts_county_mapping.get(theme),
                    "county_column": county_mapping.get(theme),
                }

        return result

    def is_valid_mode(self, mode: str):
        if mode in ["tract", "county", "tract_county"]:
            return True
        return False

    def get_schema_from_mode(self, mode: str):
        if not self.is_valid_mode(mode):
            raise ValueError(f"Mode {mode} not supported")
        return (
            self.tract_dataset_schema
            if mode == "tract"
            else (
                self.county_dataset_schema
                if mode == "county"
                else self.county_tract_dataset_schema
            )
        )

    def get_dataset_from_mode(self, mode: str):
        if not self.is_valid_mode(mode):
            raise ValueError(f"Mode {mode} not supported")
        return (
            self.tract_dataset
            if mode == "tract"
            else (
                self.county_dataset if mode == "county" else self.county_tract_dataset
            )
        )

    def get_dataset_columns(self, mode: str, filter_disease: Optional[str] = None):
        disease_key = (
            "county_columns"
            if (mode == "county" or mode == "tract_county")
            else "tract_columns"
        )
        themes_key = (
            "county_column"
            if mode == "county"
            else ("tract_column" if mode == "tract" else "tract_county_column")
        )
        schema_from_mode = self.get_schema_from_mode(mode=mode)
        if filter_disease:
            if self.disease_map.get(filter_disease):
                disease_percent_cols = [
                    self.disease_map.get(filter_disease)[disease_key]["percent"]
                ]
                disease_percentile_cols = [
                    self.disease_map.get(filter_disease)[disease_key]["percentile"]
                ]
            else:
                raise ValueError(f"Disease {filter_disease} not found")
        else:
            disease_percent_cols = [
                disease[disease_key]["percent"]
                for disease in self.diseases
                if not disease["is_tract"]
            ]
            disease_percentile_cols = [
                disease[disease_key]["percentile"]
                for disease in self.diseases
                if not disease["is_tract"]
            ]

        theme_cols = [theme[themes_key] for theme in self.themes]
        population_cols = SdohPreprocessor.get_population_columns(
            schema_from_mode, mode
        )
        other_cols = SdohPreprocessor.get_other_columns(schema_from_mode, mode)

        return (
            other_cols
            + population_cols
            + disease_percent_cols
            + disease_percentile_cols
            + theme_cols
        )

    def preprocess(self, df: pd.DataFrame, mode: str, disease: str):
        if not self.is_valid_mode(mode):
            raise ValueError(f"Mode {mode} not supported")

        df = df.dropna()
        if mode == "tract":
            df = df.dropna()
            df["FIPS"] = df["FIPS"].astype(int).astype(str).str.zfill(11)
            df["State_code"] = df["FIPS"].apply(lambda x: x[0:2])
            df["County_code"] = df["FIPS"].apply(lambda x: x[2:5])
            ## Cols renaming disease
            df = df.rename(
                columns={
                    self.disease_map[disease]["tract_columns"][
                        "percent"
                    ]: "disease_percent",
                    self.disease_map[disease]["tract_columns"][
                        "percentile"
                    ]: "disease_percentile",
                }
            )
            ## Cols renaming theme
            df = df.rename(
                columns={theme["tract_column"]: theme["name"] for theme in self.themes}
            )
            ## Raname population cols
            df = df.rename(
                columns={
                    SdohConfig.POPULATION_COL_TRACT: "population",
                    SdohConfig.POPULATION_DENSITY_COL_TRACT: "population_density",
                }
            )
        else:
            df["FIPS"] = df["FIPS"].astype(int).astype(str).str.zfill(5)
            df["State_code"] = df["FIPS"].apply(lambda x: x[0:2])
            df["County_code"] = df["FIPS"].apply(lambda x: x[2:])

            df = df.rename(
                columns={
                    self.disease_map[disease]["county_columns"][
                        "percent"
                    ]: "disease_percent",
                    self.disease_map[disease]["county_columns"][
                        "percentile"
                    ]: "disease_percentile",
                }
            )
            if mode == "county":
                df = df.rename(
                    columns={
                        theme["county_column"]: theme["name"] for theme in self.themes
                    }
                )
            else:
                df = df.rename(
                    columns={
                        theme["tract_county_column"]: theme["name"]
                        for theme in self.themes
                    }
                )
            df = df.rename(
                columns={
                    SdohConfig.POPULATION_COL_COUNTY: "population",
                    SdohConfig.POPULATION_DENSITY_COL_COUNTY: "population_density",
                }
            )

        df = df.drop_duplicates(subset="FIPS", keep="first")
        return df

    def get_filtered_df(self, mode: str, state: str, county: str, disease: str):
        dataset = self.get_dataset_from_mode(mode=mode)
        columns = self.get_dataset_columns(mode=mode, filter_disease=disease)
        if state == "all":
            df = self.preprocess(
                dataset.get_dataframe(columns=columns), mode=mode, disease=disease
            )
        else:
            df: Optional[pd.DataFrame] = None
            for df_ in dataset.iter_dataframes(columns=columns):
                filtered_df_ = self.preprocess(df_, mode=mode, disease=disease)
                filtered_df_ = filtered_df_[filtered_df_["State_code"] == state]
                if county != "all":
                    filtered_df_ = filtered_df_[filtered_df_["County_code"] == county]
                if df is None:
                    df = filtered_df_
                else:
                    df = pd.concat([df, filtered_df_])

        return df

    def get_geometries(self, state: str, county: str):
        if county == "all":
            if state == "all":
                return self.counties_geometry
            else:
                return {
                    "type": "FeatureCollection",
                    "features": [
                        item
                        for item in self.counties_geometry["features"]
                        if item["properties"]["STATEFP"] == state
                    ],
                }
        else:
            return {
                "type": "FeatureCollection",
                "features": [
                    item
                    for item in self.tracts_geometry["features"]
                    if item["properties"]["STATEFP"] == state
                    and item["properties"]["COUNTYFP"] == county
                ],
            }

    def get_data(self, meta: Dict[str, Any]):
        state, county, disease_map = itemgetter("state", "county", "disease")(meta)
        disease = disease_map["value"]
        is_tract_disease = disease_map["isTract"]
        county_mode = "tract_county" if is_tract_disease else "county"
        geometries = self.get_geometries(state=state, county=county)
        df_county = self.get_filtered_df(
            mode=county_mode, state=state, county=county, disease=disease
        )
        df_tracts = (
            self.get_filtered_df(
                mode="tract", state=state, county=county, disease=disease
            )
            if county_mode == "tract_county"
            else df_county
        )
        return {
            "counties_data": None if county != "all" else df_county.to_dict("list"),
            "tracts_data": df_tracts.to_dict("list"),
            "geometries": geometries,
        }
