import dataiku
import pickle
from .config import CsiConfig
from ..utils import (
    filter_dataset_by_nctid,
    get_keys_from_typed_dict,
    df_to_dict,
    include_sdoh,
)
from webaiku.apis.dataiku.api import dataiku_api
from ..models import (
    StudyInfo,
    StudyDesign,
    StudyCollaborator,
    StudyCondition,
    StudyIntervention,
    StudyEligibility,
    StudyArmGroup,
    StudyScore,
    StudySite,
    StudySummary,
    StudyInvestigators,
)
import pandas as pd
import numpy as np
from enum import Enum
import pickle
from typing import Dict, Any, Optional, List
from functools import lru_cache
import json


class CohortAge(str, Enum):
    """Closed set of age cohorts used to describe study populations.

    Members subclass ``str``, so each compares equal to its own name
    (e.g. ``CohortAge.CHILD == "CHILD"``).
    """

    # Declaration order is significant: it is the iteration order of the enum.
    CHILD = "CHILD"
    ADULT = "ADULT"
    OLDER_ADULT = "OLDER_ADULT"


# Lookup table from a study model type to the name of the Dataiku dataset its
# rows are read from. The dataset backing StudyScore is decided exactly once,
# when this module is imported, by calling include_sdoh().
TYPE_MAPPING_DATASET = dict(
    [
        (StudyInfo, CsiConfig.STUDY_DESCRIPTION_DATASET_NAME),
        (StudyDesign, CsiConfig.STUDY_DESIGN_DATASET_NAME),
        (StudyCollaborator, CsiConfig.STUDY_COLLABERATORS_DATASET_NAME),
        (StudyCondition, CsiConfig.STUDY_CONDITIONS_DATASET_NAME),
        (StudyIntervention, CsiConfig.STUDY_INTERVENTIONS_DATASET_NAME),
        (StudyEligibility, CsiConfig.ELIGIBILITY_DATASET_NAME),
        (StudyArmGroup, CsiConfig.ARMGROUP_DATASET_NAME),
        (
            StudyScore,
            CsiConfig.STUDIES_W_SCORES_SDOH_DATASET_NAME
            if include_sdoh()
            else CsiConfig.STUDIES_W_SCORES_DATASET_NAME,
        ),
        (StudySite, CsiConfig.STUDIES_W_SITES_JOINED_DATASET_NAME),
        (StudyInvestigators, CsiConfig.STUDY_INVESTIGATORS_DATASET_NAME),
    ]
)


class StudySimilarity:
    """Reference data and per-study lookups for the study-similarity feature.

    Creating an instance reads from Dataiku:
      * the pickled NCT-ID index in the similarity-index managed folder,
      * the ``MeshTerm`` column of the study conditions dataset,
      * the ``Sex`` column of the eligibility dataset.
    """

    def __init__(self):
        # NCT IDs unpickled from the similarity index folder, as a list.
        self.nctids = StudySimilarity.load_nctid_id_index()
        # Distinct MeshTerm values across the study conditions dataset.
        self.meshTerms_values = StudySimilarity.get_studies_mesh_terms(
            dataiku.Dataset(
                project_key=dataiku_api.project_key,
                name=CsiConfig.STUDY_CONDITIONS_DATASET_NAME,
            )
        )
        # Distinct values of the eligibility dataset's "Sex" column.
        self.cohort_sex_values = StudySimilarity.get_cohort_sex_values(
            dataiku.Dataset(
                project_key=dataiku_api.project_key,
                name=CsiConfig.ELIGIBILITY_DATASET_NAME,
            )
        )
        # Age cohorts are a fixed set rather than read from data.
        self.cohort_age_values = [
            CohortAge.CHILD,
            CohortAge.ADULT,
            CohortAge.OLDER_ADULT,
        ]

    @staticmethod
    def load_nctid_id_index():
        """Return the contents of the pickled NCT-ID index file as a list.

        Reads ``CsiConfig.STUDY_IDS_INDEX_PK_FILE`` from the managed folder
        ``CsiConfig.SIMILARITY_INDEX_FOLDER_ID`` in the current project.
        """
        index_folder = dataiku.Folder(
            CsiConfig.SIMILARITY_INDEX_FOLDER_ID, project_key=dataiku_api.project_key
        )
        # Read the raw bytes and release the stream before deserializing.
        with index_folder.get_download_stream(
            CsiConfig.STUDY_IDS_INDEX_PK_FILE
        ) as stream:
            data = stream.read()
        # SECURITY: pickle.loads can execute arbitrary code. This assumes the
        # managed folder is only written by trusted project recipes — confirm
        # before exposing the folder to external uploads.
        return list(pickle.loads(data))

    @staticmethod
    def get_studies_mesh_terms(dataset: dataiku.Dataset):
        """Return the distinct values of the dataset's ``MeshTerm`` column.

        Values keep their order of first appearance (``pd.unique`` semantics);
        missing values, if present, are kept as well.
        """
        df = dataset.get_dataframe(columns=["MeshTerm"])
        return list(pd.unique(df["MeshTerm"]))

    @staticmethod
    def get_cohort_sex_values(dataset: dataiku.Dataset):
        """Return the distinct values of the dataset's ``Sex`` column.

        Values keep their order of first appearance (``pd.unique`` semantics);
        missing values, if present, are kept as well.
        """
        df = dataset.get_dataframe(columns=["Sex"])
        return list(pd.unique(df["Sex"]))

    @staticmethod
    def get_similar_studies(nctid: str):
        """Placeholder for similar-study retrieval.

        TODO: not implemented — currently always returns ``None``.
        """
        return None

    @staticmethod
    @lru_cache(maxsize=30)
    def get_study_summary(nctid: str):
        """Assemble a ``StudySummary`` for *nctid* from all per-model datasets.

        Results are memoized for up to 30 distinct NCT IDs. Repeated calls
        return the *same* cached object, so callers must not mutate it.

        Single-record sections (info, design, eligibility, score) are the
        first matching row or ``None``. List sections are every matching row,
        or ``None`` when ``df_to_dict`` yields nothing.
        """
        study_info: Optional[StudyInfo] = StudySimilarity.get_model_data(
            nctid=nctid, model=StudyInfo, keep_first=True
        )
        study_design: Optional[StudyDesign] = StudySimilarity.get_model_data(
            nctid=nctid, model=StudyDesign, keep_first=True
        )
        study_collaborators: Optional[
            List[StudyCollaborator]
        ] = StudySimilarity.get_model_data(
            nctid=nctid, model=StudyCollaborator, keep_first=False
        )
        study_conditions: Optional[
            List[StudyCondition]
        ] = StudySimilarity.get_model_data(
            nctid=nctid, model=StudyCondition, keep_first=False
        )
        study_interventions: Optional[
            List[StudyIntervention]
        ] = StudySimilarity.get_model_data(
            nctid=nctid, model=StudyIntervention, keep_first=False
        )
        study_eligibility: Optional[StudyEligibility] = StudySimilarity.get_model_data(
            nctid=nctid, model=StudyEligibility, keep_first=True
        )
        study_arms: Optional[List[StudyArmGroup]] = StudySimilarity.get_model_data(
            nctid=nctid, model=StudyArmGroup, keep_first=False
        )
        study_score: Optional[StudyScore] = StudySimilarity.get_model_data(
            nctid=nctid, model=StudyScore, keep_first=True
        )
        study_sites: Optional[List[StudySite]] = StudySimilarity.get_model_data(
            nctid=nctid, model=StudySite, keep_first=False
        )
        study_investigators: Optional[
            List[StudyInvestigators]
        ] = StudySimilarity.get_model_data(
            nctid=nctid, model=StudyInvestigators, keep_first=False
        )
        return StudySummary(
            info=study_info,
            design=study_design,
            collaborators=study_collaborators,
            conditions=study_conditions,
            interventions=study_interventions,
            eligibility=study_eligibility,
            arms=study_arms,
            score=study_score,
            sites=study_sites,
            investigators=study_investigators,
        )

    @staticmethod
    def get_model_data(nctid: str, model: type, keep_first: bool = True):
        """Load the rows for *nctid* from the dataset mapped to *model*.

        Args:
            nctid: Study identifier used to filter the dataset.
            model: A key of ``TYPE_MAPPING_DATASET``. Its keys (via
                ``get_keys_from_typed_dict``) select the columns to read.
            keep_first: When true, build one ``model`` from the single record
                returned by ``df_to_dict``. Otherwise build a list of
                ``model`` objects from the returned records.

        Returns:
            A ``model`` instance, a list of them, or ``None`` when
            ``df_to_dict`` returns ``None``.

        Raises:
            ValueError: if *model* has no dataset in ``TYPE_MAPPING_DATASET``.
        """
        dataset_name = TYPE_MAPPING_DATASET.get(model)
        if dataset_name is None:
            raise ValueError(
                f"No dataset defined for the provided model {model.__name__}"
            )
        df_filtered = filter_dataset_by_nctid(
            datasetName=dataset_name,
            nctid=nctid,
            columns=get_keys_from_typed_dict(model),
        )
        json_data = df_to_dict(df_filtered, keep_first)
        if json_data is None:
            return None
        if keep_first:
            return model(**json_data)
        return [model(**item) for item in json_data]


# Module-level shared instance. Constructing it at import time reads the
# NCT-ID index folder and the conditions/eligibility datasets from Dataiku,
# so importing this module performs that I/O.
study_similarity = StudySimilarity()
