from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from rdkit.DataStructs import TanimotoSimilarity
from rdkit.Chem import Draw
import matplotlib.pyplot as plt

def calculate_tanimoto_similarity(fingerprint1, fingerprint2):
    """Calculates the Tanimoto similarity between two molecular fingerprints.
    Args:    fingerprint1: The first fingerprint object, fingerprint2: The second fingerprint object.
    Returns: The Tanimoto similarity score (a single float, not a tuple) between fingerprint1 and fingerprint2,
             as computed by DataStructs.TanimotoSimilarity.
    Raises:  ValueError: If either fingerprint is None (e.g. a molecule whose SMILES could not be parsed).
    """
    if fingerprint1 is None or fingerprint2 is None:
        raise ValueError("Fingerprints cannot be None.")
    return DataStructs.TanimotoSimilarity(fingerprint1, fingerprint2)

def compute_fingerprints(dataframe, smiles_col="canonical_smiles"):
    """
    Calculates RDKit fingerprints for a DataFrame of molecules.
    The fingerprints are bit vectors, which is the representation the Tanimoto similarity
    score is defined on; hence Chem.RDKFingerprint is used.
    Args:    dataframe: A pandas DataFrame containing a column of SMILES strings,
             smiles_col: The name of the column containing SMILES strings (default: 'canonical_smiles').
    Returns: The same DataFrame object, modified in place, with two new columns:
             'mol_vectors' (RDKit molecule objects) and 'fingerprint' (RDKit fingerprints).
             Rows whose SMILES is missing, not a string, or cannot be parsed get None in both
             columns instead of aborting the computation for the whole DataFrame.
    Raises:  KeyError: If smiles_col is not a column of dataframe.
    """
    def to_mol(smiles):
        # Chem.MolFromSmiles returns None for unparsable SMILES but raises on non-string
        # input (such as the NaN pandas uses for missing values), so screen those out first.
        if not isinstance(smiles, str):
            return None
        return Chem.MolFromSmiles(smiles)

    def to_fingerprint(mol):
        # Chem.RDKFingerprint raises on None, so propagate the "no molecule" marker instead.
        if mol is None:
            return None
        return Chem.RDKFingerprint(mol)

    # Parse molecule objects and store them directly in a new column.
    dataframe['mol_vectors'] = dataframe[smiles_col].apply(to_mol)

    # Chem.RDKFingerprint is RDKit's topological (path-based) fingerprint, returned as a bit vector.
    dataframe['fingerprint'] = dataframe['mol_vectors'].apply(to_fingerprint)
    return dataframe

def _with_atom_notes(mol):
    """Return *mol* unchanged if it has no conformers, else a copy whose atoms are labelled "<symbol><1-based index>".

    The labels are stored in the "atomNote" atom property, which RDKit's MolDraw2D-based
    drawing renders next to each atom. The labels therefore sit on the depiction itself rather than
    at raw conformer coordinates on the image axes. A copy is made so the caller's molecule is not mutated.
    """
    if mol.GetNumConformers() == 0:
        return mol
    annotated = Chem.Mol(mol)
    for atom in annotated.GetAtoms():
        atom.SetProp("atomNote", f"{atom.GetSymbol()}{atom.GetIdx() + 1}")
    return annotated


def molecule_graph(mol_vecs, selected_mol, score_mol, legends, similarity_score):
    """
    Generates a Matplotlib figure displaying similar molecules side by side, one panel per molecule.
    Args:    mol_vecs: A sequence of RDKit molecule objects to draw.
             selected_mol: The reference molecule; it is formatted with str() into the figure title
                           (e.g. a SMILES string or an identifier).
             score_mol: The predicted pIC50 for the selected molecule, shown in the title.
             legends: A list of panel titles, one per molecule.
             similarity_score: A list of similarity scores, one per molecule, shown beneath each panel.
    Returns: A Matplotlib figure object.
    Raises:  ValueError: If mol_vecs is empty.
    Molecules that have conformers are additionally drawn with per-atom "<symbol><index>" labels.
    """
    n_mols = len(mol_vecs)
    if n_mols == 0:
        raise ValueError("mol_vecs must contain at least one molecule.")

    # Create a Matplotlib figure. squeeze=False keeps `axes` a 2-D array even for a single
    # panel; otherwise plt.subplots returns a bare Axes object and the zip below would fail.
    figure, axes = plt.subplots(nrows=1, ncols=n_mols, figsize=(n_mols * 5, 5), squeeze=False)

    # Add a main title to the figure.
    figure.suptitle(
        f"The most similar active molecules to: {selected_mol} with pIC50 prediction {score_mol}",
        fontsize=16,
    )

    for ax, mol, legend, score in zip(axes[0], mol_vecs, legends, similarity_score):
        ax.imshow(Draw.MolToImage(_with_atom_notes(mol)))
        ax.text(0.5, -0.1, score, ha="center", transform=ax.transAxes, fontsize=12)
        ax.set_title(legend)
        ax.axis('off')
    return figure

