import pandas as pd
import numpy as np
import os
import io
import tempfile
from pathlib import Path
import shutil
import pickle

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import OrdinalEncoder
from sklearn.metrics import r2_score

from statsmodels.stats.outliers_influence import variance_inflation_factor
    

# SAVE AND LOAD SCALERS
class SaveLoad():
    """Pickle-based persistence helpers for scalers (or any picklable object).

    Both helpers work against a remote "folder handle" object that exposes
    ``get_writer(path)`` and ``get_download_stream(path)`` as context managers
    (presumably a Dataiku managed ``Folder`` -- confirm against the caller).

    The helpers are static: they can be called on the class
    (``SaveLoad.save_pickle(...)``) or on an instance.
    """

    @staticmethod
    def save_pickle(obj_to_save, experiment_folder_handle=None, scaler_path=None):
        """Pickle ``obj_to_save`` and write it to ``scaler_path`` in the folder.

        :param obj_to_save: any picklable object.
        :param experiment_folder_handle: folder handle providing ``get_writer``.
        :param scaler_path: path of the target file inside the folder.
        """
        with experiment_folder_handle.get_writer(scaler_path) as writer:
            writer.write(pickle.dumps(obj_to_save))

    @staticmethod
    def load_pickle(experiment_folder_handle=None, scaler_path=None):
        """Download ``scaler_path`` from the folder and return the unpickled object.

        The remote stream is first copied to a temporary local file, which is
        removed again before returning.

        :param experiment_folder_handle: folder handle providing
            ``get_download_stream``.
        :param scaler_path: path of the pickled file inside the folder.
        :return: the object stored in the pickle file.
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            # Strip leading slashes so the path always stays inside the
            # temporary directory (os.path.join would otherwise discard it).
            local_file_path = os.path.join(tmpdirname, scaler_path.lstrip('/'))

            # Create the local sub-directories mirroring the remote path.
            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

            # Copy the file from the remote folder to the local temp file.
            with experiment_folder_handle.get_download_stream(scaler_path) as f_remote, \
                    open(local_file_path, 'wb') as f_local:
                shutil.copyfileobj(f_remote, f_local)

            # SECURITY NOTE: unpickling executes arbitrary code; only load
            # files that were produced by a trusted source.
            # The file is closed explicitly so the temporary directory can be
            # cleaned up (an open handle blocks deletion on Windows).
            with open(local_file_path, 'rb') as f_local:
                return pickle.load(f_local)


class CustomMeanScaler:
    """Scale a single feature by dividing it by its mean, so the result has a mean of 1.

    Input: ``pd.Series`` (its ``name`` becomes the output column name).
    Easy to implement in visual ML pipelines.

    Attributes set by :meth:`fit`:
        mean_value: NaN-ignoring mean of the fitted series.
        feature_name: name of the fitted series.
    """

    def __init__(self):
        self.mean_value   = None
        self.feature_name = None

    def fit(self, df):
        """Learn the (NaN-ignoring) mean and the name of the series ``df``.

        :param df: ``pd.Series`` to learn the mean from.
        :return: ``self``, so calls can be chained (``fit(s).transform(s)``).
        """
        self.feature_name = df.name
        self.mean_value   = np.nanmean(df.values)
        return self

    def transform(self, df) -> pd.DataFrame:
        """Divide the values of ``df`` by the mean learned in :meth:`fit`.

        :param df: ``pd.Series`` to scale.
        :return: one-column ``pd.DataFrame`` named after the fitted feature.
            It is built from the raw values, so the input index is not kept
            and the result has a fresh default index.
        """
        scaled_values = df.values / self.mean_value
        return pd.DataFrame(scaled_values, columns=[self.feature_name])
    
    
class CustomScalerSuper(OrdinalEncoder):
    """
        Example with the OrdinalEncoder inherited class.

        ``fit`` is inherited unchanged from ``OrdinalEncoder``; only
        ``transform`` is overridden, and it ignores the fitted categories:
        it divides the input by its own overall NaN-ignoring mean.
    """

    # NOTE: ``__init__`` is intentionally not overridden. The constructor is
    # inherited from OrdinalEncoder and accepts the same keyword arguments as
    # the previous ``__init__(self, **kwargs)`` pass-through. scikit-learn reads
    # hyper-parameters from the explicit ``__init__`` signature, so a
    # ``**kwargs``-only signature hides them from ``get_params()``/``clone()``.

    def transform(self, X, y=None):
        """Return ``X`` divided by its overall NaN-ignoring mean.

        :param X: array-like holding a single numeric column (ndarray,
            ``pd.Series`` or one-column ``pd.DataFrame``).
        :param y: ignored; accepted for scikit-learn API compatibility.
        :return: ``pd.DataFrame`` with the single column ``"original_column"``.
            The index of ``X`` is kept when ``X`` is a pandas object.
        """
        self.names = ["original_column"]

        # Work on the raw values: building a DataFrame from another pandas
        # object with ``columns=`` re-selects columns by label, which yields
        # an all-NaN column when the input column has a different name.
        values = np.asarray(X)
        scaled = values / np.nanmean(values)

        index = X.index if isinstance(X, (pd.DataFrame, pd.Series)) else None
        return pd.DataFrame(scaled, columns=self.names, index=index)
 
    
# Metrics comparing forecasted and actual target values
def compute_metrics(obs,pred,model_name="None") -> dict:

    """
    Compute point-forecast accuracy metrics between observed and predicted values.

    Inputs are converted to flat float arrays and compared position by
    position (a pandas index is ignored). Pairs in which either value is NaN
    are dropped up front, so every metric is computed on the same sample.

    These are not Bayesian metrics (for Bayesian ones you need to take the
    samples/chains, compute the metric per chain, then average to obtain the STD).

    :param obs: array-like (e.g. pd.Series) of observed values
    :param pred: array-like (e.g. pd.Series) of forecasted values
    :param model_name: label stored under the 'model_name' key of the result
    :return: dict with the following keys:
        'model_name': the given ``model_name``
        'mse':  mean squared error
        'rmse': square root of 'mse'
        'mape': mean absolute percentage error, as a fraction (not multiplied
                by 100); observations equal to 0 produce inf/NaN terms
        'r2':   coefficient of determination (``sklearn.metrics.r2_score``)
        'R2':   Pearson correlation between observations and predictions
    """

    assert len(obs) == len(pred), f'accuracy(): obs len is {len(obs)} but preds len is {len(pred)}'

    # Flatten to 1D float arrays so all metrics below are positional.
    obs  = np.asarray(obs, dtype=float).ravel()
    pred = np.asarray(pred, dtype=float).ravel()

    # Keep only the pairs where both values are present.
    valid = ~(np.isnan(obs) | np.isnan(pred))
    obs, pred = obs[valid], pred[valid]

    errors = obs - pred

    mse  = np.nanmean(errors**2)
    rmse = np.sqrt(mse)  # derived from mse so both metrics always agree
    mape = np.nanmean(np.abs(errors / obs))
    r2   = r2_score(obs,pred) # coefficient of determination
    R2   = np.corrcoef(obs,pred)[0,1] # Pearson's correlation between prediction and target

    metrics = {'model_name':model_name,'mse': mse, 'rmse': rmse, 'mape':mape,'r2':r2,'R2':R2}
    return metrics

def compute_vif(X) -> pd.DataFrame:

    """
    Build a table with the Variance Inflation Factor (VIF) of every column of X.

    :param X: pd.DataFrame of features
    :return: pd.DataFrame with one row per column of X and two columns:
        "features" (the column labels of X) and "VIF Factor".
    """
    design_matrix = X.values
    column_positions = range(X.shape[1])
    scores = [
        variance_inflation_factor(design_matrix, position)
        for position in column_positions
    ]

    report = pd.DataFrame()
    report["features"] = X.columns
    report["VIF Factor"] = scores
    return report
