# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
import dataiku
from dataiku import pandasutils as pdu
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
from transformers import AutoTokenizer, AutoModel
import faiss
import pickle

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Read recipe inputs: handle on the Dataiku dataset whose rows are embedded
# and added to the similarity index further down.
studies_for_similarity = dataiku.Dataset("studies_for_similarity")

# Write recipe outputs: Dataiku folder that receives the pickled label
# encoders, the serialized FAISS index and the NCTId array.
similarity_index_folder = dataiku.Folder("O4J03zf7")

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Load the BERT tokenizer and model (Bio_ClinicalBERT checkpoint) once at
# module level; embed_text() reads both as globals.
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")


def embed_text(df, column, batch_size=32, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Embed one text column of ``df`` with the module-level BERT model.

    Missing values are replaced by the literal string ``'NA'``. The whole
    column is tokenized in one call (truncated to 128 tokens), the model is
    run over it ``batch_size`` rows at a time, the last hidden state is
    averaged over the token axis and every resulting row is L2-normalized
    with ``normalize_vector``.

    Args:
        df: DataFrame holding the text column.
        column: Name of the column to embed.
        batch_size: Number of rows per forward pass.
        device: Torch device used for the forward pass. The default is
            evaluated once, when the function is defined.

    Returns:
        A 2-D CPU tensor with one normalized embedding per row of ``df``.
        An empty ``df`` yields a tensor with zero rows.
    """
    texts = list(df[column].fillna('NA'))
    if not texts:
        # Nothing to tokenize: return an empty matrix of the model's width
        # instead of failing somewhere inside the tokenizer or torch.cat.
        return torch.empty((0, model.config.hidden_size))

    # Tokenize the whole column at once. The tokens stay on the CPU; only the
    # batch currently being processed is moved to `device`.
    # NOTE(review): padding=True pads every row to the longest row of the
    # column, and the mean below also averages those padding positions, so a
    # row's embedding depends on the other rows passed in. Left unchanged so
    # the vectors stay comparable with indexes that were already built;
    # consider attention-mask-weighted pooling.
    tokenized_texts = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors='pt',
        max_length=128)

    # The inputs are moved to `device`, so the model has to live there too;
    # otherwise the forward pass fails with a device mismatch on GPU hosts.
    model.to(device)

    num_rows = tokenized_texts['input_ids'].shape[0]
    column_embeddings = []
    for start_idx in tqdm(range(0, num_rows, batch_size), desc=f'Embedding {column}'):
        end_idx = start_idx + batch_size  # slicing clamps at the last row
        batch_tokenized_texts = {
            key: value[start_idx:end_idx].to(device)
            for key, value in tokenized_texts.items()}
        with torch.no_grad():
            outputs = model(**batch_tokenized_texts)
            # Use mean pooling to get sentence embeddings
            batch_embeddings = outputs.last_hidden_state.mean(dim=1)

        # Normalize, then bring the result back to the CPU so it can be
        # concatenated with CPU tensors by the caller.
        column_embeddings.append(normalize_vector(batch_embeddings).cpu())

    # Concatenate the normalized embeddings from all batches
    return torch.cat(column_embeddings, dim=0)


def embed_categorical_variable(df, column):
    """Map a categorical column to L2-normalized embedding vectors.

    The column is cast to ``str``, label-encoded with the pickled encoder
    returned by ``load_label_encoder(column)`` and looked up in a freshly
    created ``nn.Embedding`` whose width equals the number of encoder
    classes. The torch seed is fixed to 42 right before, presumably so the
    same table is produced on every call.

    Args:
        df: DataFrame holding the categorical column.
        column: Name of the column; also selects the label encoder file.

    Returns:
        A 2-D tensor with one unit-length row per row of ``df``.
    """
    # Fix the seeds (CPU and, when present, all GPUs) before the embedding
    # layer is created. This resets the global torch RNG as a side effect.
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)

    raw_values = df[column].astype(str).values

    # Integer code per row, from the encoder stored for this column.
    encoder = load_label_encoder(column)
    codes = torch.tensor(encoder.transform(raw_values))

    # One embedding row per known category; the embedding width equals the
    # number of categories.
    n_categories = len(encoder.classes_)
    lookup = nn.Embedding(n_categories, n_categories)

    return nn.functional.normalize(lookup(codes), p=2, dim=-1)


def normalize_vector(vector, eps=1e-12):
    """Scale ``vector`` to unit L2 norm along its last dimension.

    Args:
        vector: Floating-point tensor; each slice along the last dimension
            is normalized independently.
        eps: Lower bound applied to the norm before dividing. It keeps an
            all-zero row at zero instead of turning it into NaN (0 / 0).

    Returns:
        A tensor of the same shape as ``vector``.
    """
    vector_norm = torch.norm(vector, p=2, dim=-1, keepdim=True)
    # Clamp so a zero-norm row cannot produce NaN values.
    return vector / vector_norm.clamp_min(eps)


def create_similarity_index(vector):
    """Build a FAISS ``IndexFlatIP`` index containing the rows of ``vector``.

    Args:
        vector: 2-D torch tensor; its second dimension sets the index width.

    Returns:
        The new index with every row of ``vector`` added.
    """
    flat_index = faiss.IndexFlatIP(vector.shape[1])

    # Hand the data over as a contiguous float32 NumPy array: leave the
    # device and the autograd graph first, then convert.
    rows = np.ascontiguousarray(
        vector.cpu().detach().numpy(), dtype=np.float32)
    flat_index.add(rows)

    return flat_index


def load_label_encoder(column):
    """Load the pickled label encoder stored for ``column``.

    Reads ``<column>_label_encoder.pkl`` from ``similarity_index_folder`` and
    unpickles it. Unpickling executes whatever the file contains, so the
    folder contents must be trusted.

    Args:
        column: Column name used as the file name prefix.

    Returns:
        The unpickled object (written by ``create_label_encoders`` in this
        file).
    """
    pickle_name = column + "_label_encoder.pkl"
    with similarity_index_folder.get_download_stream(pickle_name) as stream:
        return pickle.loads(stream.read())


def create_label_encoders():
    """Fit one ``LabelEncoder`` per categorical column and pickle it.

    The category values are hard-coded here. Each fitted encoder is written
    to ``similarity_index_folder`` as ``<column>_label_encoder.pkl``.
    """
    known_categories = {
        'age_group_label': ['0-1-1', '1-1-1', '1-1-0', '0-1-0', '0-0-1', '1-0-0'],
        'Sex': ['MALE', 'ALL', 'FEMALE'],
        'HealthyVolunteers': ['False',  'True'],
    }

    for column_name, categories in known_categories.items():
        # Only the fitted encoder is needed; the encoded values are not.
        encoder = LabelEncoder().fit(categories)
        with similarity_index_folder.get_writer(f"{column_name}_label_encoder.pkl") as writer:
            writer.write(pickle.dumps(encoder))

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
def embed_features(df, columns):
    """Embed the supported columns of ``df`` into one normalized vector per row.

    Text columns are embedded with ``embed_text`` and categorical columns
    with ``embed_categorical_variable``; names in ``columns`` that belong to
    neither group are ignored. The per-column embeddings are concatenated
    along dimension 1 (text columns first, each group in the order given by
    ``columns``) and the concatenation is L2-normalized again.

    Args:
        df: DataFrame holding the columns to embed.
        columns: Iterable of column names to consider.

    Returns:
        A 2-D tensor with one row per row of ``df``.
    """
    text_columns = {
        'BriefSummary', 'inclusion_criteria1', 'exclusion_criteria1', 'MeshTerm_Conditions'}
    category_columns = {
        'age_group_label', 'Sex', 'HealthyVolunteers'}

    selected_text = [name for name in columns if name in text_columns]
    selected_categories = [name for name in columns if name in category_columns]

    parts = []
    for name in selected_text:
        text_vector = embed_text(df, name)
        # MeshTerm_Conditions is scaled by 2 before concatenation, which
        # gives it more weight in the final normalized vector.
        if name == "MeshTerm_Conditions":
            text_vector = text_vector * 2
        parts.append(text_vector)

    for name in selected_categories:
        parts.append(embed_categorical_variable(df, name))

    return normalize_vector(torch.cat(parts, dim=1))

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Instantiate label encoders: fit them and write them to the output folder,
# where load_label_encoder() reads them back during embedding.

create_label_encoders()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Instantiate index
#
# A 32-row sample is embedded only to learn the width of the combined
# feature vector, which the FAISS index needs at construction time. The
# sample rows themselves are not added to the index here.

columns_for_embedding = [
    'BriefSummary',
    'inclusion_criteria1',
    'exclusion_criteria1',
    'MeshTerm_Conditions',
    'age_group_label',
    'Sex',
    'HealthyVolunteers',
]
slice_df = studies_for_similarity.get_dataframe(limit=32)
sample_embedded_vector = embed_features(slice_df, columns_for_embedding)

data_dimension = sample_embedded_vector.shape[1]
index = faiss.IndexFlatIP(data_dimension)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Embed the dataset chunk by chunk, add every chunk to the FAISS index and
# remember the NCTId values in the same order as the rows that were added.
nctid_arrays = []

for partial_dataframe in studies_for_similarity.iter_dataframes(chunksize=3200):
    embeded_features_tensor = embed_features(partial_dataframe, columns_for_embedding)
    # Convert PyTorch tensor to NumPy array
    numpy_array = embeded_features_tensor.cpu().detach().numpy()
    # Convert NumPy array to contiguous array with data type np.float32
    contiguous_array = np.ascontiguousarray(numpy_array, dtype=np.float32)
    index.add(contiguous_array)

    nctid_array = partial_dataframe['NCTId'].values
    nctid_arrays.append(nctid_array)
    print(f"{len(nctid_array)} studies added to the index")

# Build both payloads before opening any writer: if serialization or the
# concatenation fails, no empty or mismatched file is left in the folder.
chunk = faiss.serialize_index(index)
index_pk = pickle.dumps(chunk)
id_vectors = np.concatenate(nctid_arrays)
id_pk = pickle.dumps(id_vectors)

with similarity_index_folder.get_writer("studies_cosine_similarity.pkl") as w:
    w.write(index_pk)

with similarity_index_folder.get_writer("nctid_array.pkl") as w2:
    w2.write(id_pk)
