import dataiku
import pandas as pd, numpy as np
from dataiku import pandasutils as pdu
from itertools import combinations, product
import math
from dotenv import load_dotenv
import warnings
import os
import datamol as dm
import rdkit
from rdkit import Chem, DataStructs
from rdkit.Chem import MACCSkeys, Descriptors, AllChem, QED
from molfeat.trans import MoleculeTransformer
from molfeat.trans.fp import FPVecTransformer
from molfeat.trans.pretrained.hf_transformers import PretrainedHFTransformer, HFModel
import torch
import transformers
from transformers import AutoModel, AutoTokenizer
from sklearn.manifold import TSNE
from tqdm.auto import tqdm

def pIC50_bioactivity(IC50, digits=3):
    """
    Convert a half-maximal inhibitory concentration (IC50) to pIC50.

    The input is interpreted as nanomolar: pIC50 = -log10(IC50 * 1e-9),
    which is the same as 9 - log10(IC50).

    Args:
        IC50: IC50 value in nM, either a scalar or an array-like of values.
              Every non-NaN entry must be positive; NaN entries propagate
              to the output as NaN.
        digits: The number of decimal places to round the pIC50 value to
                (default: 3).

    Returns:
        The rounded pIC50 as a NumPy float for scalar input, or a NumPy
        array of the same shape for array-like input.

    Raises:
        ValueError: If any IC50 value is zero or negative.
    """
    values = np.asarray(IC50, dtype=float)
    # NaN compares False here, so missing values are not rejected and simply
    # come back as NaN from log10.
    if np.any(values <= 0):
        raise ValueError("IC50 value must be positive")
    # 9 - log10(x) equals -log10(x * 1e-9) but skips the multiplication by the
    # inexact float 1e-9, so exact powers of ten stay exact before rounding.
    return np.round(9.0 - np.log10(values), decimals=digits)

def molecular_featurizer(smiles, transformer):
    """
    Generates molecular feature representations from a list of SMILES strings.

    Args:
        smiles (list): A list of SMILES strings. A single SMILES string is
            treated as a one-element list.
        transformer (str): The type of featurizer to use. Supported options:
            - 'ecfp': Extended Connectivity Fingerprints (molfeat FPVecTransformer)
            - 'mordred': Mordred descriptors (molfeat FPVecTransformer)
            - 'MACCS': MACCS keys (computed by get_MACCS_keys)
            - 'ChemBERTa': mean-pooled last hidden state of the
              "DeepChem/ChemBERTa-77M-MLM" model
            - 'Roberta-Zinc': molfeat's pretrained 'Roberta-Zinc480M-102M' model

    Returns:
        Depends on the featurizer:
            - 'ecfp' / 'mordred': the first element of the transformer call's
              result, or an (n, 1) NaN array if the call raises.
            - 'MACCS': whatever get_MACCS_keys returns for the parsed molecules.
            - 'ChemBERTa': a NumPy array with one row per SMILES; records that
              fail to tokenize or embed become all-NaN rows.
            - 'Roberta-Zinc': the unmodified result of the transformer call,
              or an (n, 1) NaN array if the call raises.

    Raises:
        ValueError: If the transformer type is invalid.
    """
    # atleast_1d keeps a lone SMILES string from becoming a 0-d array, which
    # has no len() and cannot be iterated.
    smiles_array = np.atleast_1d(np.array(smiles))
    smiles_list = smiles_array.tolist()

    if transformer in ['ecfp', 'mordred']:
        trans = FPVecTransformer(kind=transformer, n_jobs=1)
        try:
            result = trans(smiles_array, ignore_errors=True)  # Allow skipping errors
        except Exception:
            # Return the whole NaN placeholder; indexing it with [0] (as was
            # done before) would keep only its first row.
            feats = np.full((len(smiles_array), 1), np.nan)
        else:
            # presumably (features, valid_ids) when ignore_errors=True —
            # confirm against the installed molfeat version
            feats = result[0]

    elif transformer == 'MACCS':
        mol_vecs = [Chem.MolFromSmiles(x) if x else None for x in smiles_array]
        # Unparsable entries are replaced by NaN; get_MACCS_keys then fails on
        # them, logs the problem and leaves them out of its result.
        mol_vecs = [mol if mol is not None else np.nan for mol in mol_vecs]
        feats = get_MACCS_keys(mol_vecs)

    elif transformer == 'ChemBERTa':
        model_name = "DeepChem/ChemBERTa-77M-MLM"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)
        # A failed record must occupy a full-width row: mixing a scalar NaN
        # with embedding vectors yields a ragged list that np.array rejects.
        nan_row = np.full(model.config.hidden_size, np.nan, dtype=np.float32)

        features_list = []
        with torch.no_grad():
            for smi in smiles_list:
                try:
                    inputs = tokenizer(smi, return_tensors="pt", padding=True)
                    outputs = model(**inputs)
                    # Average over the token dimension to get one vector per molecule
                    averaged_embeddings = torch.mean(outputs.last_hidden_state, dim=1)
                    features_list.append(averaged_embeddings.squeeze().cpu().numpy())
                except Exception:
                    features_list.append(nan_row)  # Handle failed tokenization/embedding

        feats = np.array(features_list)

    elif transformer == 'Roberta-Zinc':
        trans = PretrainedHFTransformer(kind='Roberta-Zinc480M-102M', notation='smiles', dtype=float)
        try:
            # NOTE(review): unlike the 'ecfp'/'mordred' branch the result is not
            # indexed with [0]; with ignore_errors=True this is presumably a
            # (features, valid_ids) tuple rather than a bare array. Left as-is
            # so existing callers keep working — confirm which form they expect.
            feats = trans(smiles_array, ignore_errors=True)
        except Exception:
            feats = np.full((len(smiles_array), 1), np.nan)

    else:
        raise ValueError("Invalid featurizer type")

    return feats

def get_MACCS_keys(molecular_vectors):
    """
    Computes MACCS key fingerprints for a sequence of RDKit molecules.

    Args:
        molecular_vectors (list): RDKit molecule objects. ``None`` entries are
            skipped silently; entries that make fingerprint generation raise
            are reported on stdout and skipped.

    Returns:
        np.ndarray: The fingerprints of all successfully processed molecules
        stacked into one array, or an empty array of shape (0, 167) when no
        molecule could be processed.
    """
    n_bits = 167  # width used for the empty result
    collected = []

    # None marks an entry that should be ignored without any log output
    candidates = (entry for entry in molecular_vectors if entry is not None)
    for candidate in candidates:
        try:
            bit_vector = MACCSkeys.GenMACCSKeys(candidate)
            # Placeholder target; ConvertToNumpyArray is expected to resize it
            # to the fingerprint length — TODO confirm
            target = np.zeros((1,), dtype=int)
            DataStructs.ConvertToNumpyArray(bit_vector, target)
        except Exception as err:
            print(f"Skipping invalid molecule: {err}")
        else:
            collected.append(target)

    if not collected:
        return np.empty((0, n_bits))
    return np.array(collected)


def molecular_descriptors(smiles_list, descriptors):
    """
    Calculates a specified set of descriptors for a list of SMILES strings.

    Each molecule is parsed with RDKit, explicit hydrogens are added, and the
    requested descriptor functions plus the QED score are evaluated on it.

    Args:
        smiles_list: A list of SMILES strings.
        descriptors: A list of attribute names of ``rdkit.Chem.Descriptors``.

    Returns:
        A list of lists; each inner list holds the SMILES string, the
        descriptor values in the order of ``descriptors``, and finally the
        QED score of that molecule.

    Raises:
        ValueError: If a SMILES string cannot be parsed, a descriptor name
            does not exist, or a descriptor calculation fails. The original
            exception is chained as the cause.
    """
    desc_data = []

    for smi in smiles_list:
        try:
            molecule_object = Chem.MolFromSmiles(smi)
            # Check before AddHs: passing None to AddHs fails with an opaque
            # argument-type error instead of naming the real problem.
            if molecule_object is None:
                raise ValueError("RDKit could not parse the SMILES string")
            molecule_object = Chem.AddHs(molecule_object)
            desc_values = [getattr(Descriptors, name)(molecule_object) for name in descriptors]
            qed_score = QED.qed(molecule_object)
        except Exception as e:
            raise ValueError(f"Error processing SMILES: {smi}. Reason: {e}") from e
        desc_data.append([smi] + desc_values + [qed_score])
    return desc_data

def rule_of_five(molecular_weight, clogp, hydrogen_bond_donors, hydrogen_bond_acceptors):
    """
    Evaluates a compound against Lipinski's Rule of Five criteria.

    A criterion counts as violated only when its limit is exceeded, so values
    exactly on a limit (e.g. 5 donors or 10 acceptors) still pass.

    Args:
        molecular_weight: Molecular weight in g/mol (limit: 500).
        clogp: Calculated octanol-water partition coefficient (limit: 5).
        hydrogen_bond_donors: Number of hydrogen bond donors (limit: 5).
        hydrogen_bond_acceptors: Number of hydrogen bond acceptors (limit: 10).

    Returns:
        The string 'pass' if all four criteria are met, otherwise 'fail'.
        NaN in any argument results in 'fail' because comparisons with NaN
        are false.
    """
    within_limits = (
        molecular_weight <= 500
        and clogp <= 5
        and hydrogen_bond_donors <= 5
        and hydrogen_bond_acceptors <= 10
    )
    return 'pass' if within_limits else 'fail'

def tsne_function(num_molecules, fingerprints):
    """
    Applies t-SNE for dimensionality reduction and visualization of molecular fingerprints.

    Args:
        num_molecules: Number used to derive the t-SNE perplexity, which is
            set to ``num_molecules - 1``.
            NOTE(review): scikit-learn is expected to require the perplexity
            to be smaller than the number of samples — confirm that callers
            pass the number of rows in ``fingerprints``.
        fingerprints: Fingerprint vectors, one row per molecule.

    Returns:
        A Pandas DataFrame with the columns "t-SNE X" and "t-SNE Y" holding
        the min-max normalized 2D t-SNE coordinates, one row per molecule.
    """
    # Perform t-SNE dimensionality reduction down to 2D for visualization;
    # random_state is fixed so repeated runs give the same embedding.
    tsne = TSNE(n_components=2, init='pca', random_state=90, angle=0.3,
                perplexity=num_molecules - 1)
    tsne_result = tsne.fit_transform(fingerprints)

    # Normalize each t-SNE axis using Min-Max scaling
    tsne_min = np.min(tsne_result, axis=0)
    tsne_range = np.max(tsne_result, axis=0) - tsne_min
    # A constant axis has zero range; divide by 1 instead so it maps to 0
    # rather than producing NaN from 0/0.
    tsne_range = np.where(tsne_range == 0, 1, tsne_range)
    tsne_result_normalized = (tsne_result - tsne_min) / tsne_range

    # One row per molecule, in the same order as the input fingerprints
    tsne_coordinates = pd.DataFrame({
        "t-SNE X": tsne_result_normalized[:, 0],
        "t-SNE Y": tsne_result_normalized[:, 1]
        })
    return tsne_coordinates
    

