# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
from dataiku import recipe
from datasets import Dataset
import torch
import huggingface_hub
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
from peft import LoraConfig

# Base model on the Hugging Face Hub, pinned to a specific revision.
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
MODEL_REVISION = "c72e5d1908b1e2929ec8fc4c8820e9706af1f80f"
# Connection name handed to create_finetuned_llm_version() when training starts.
connection_name = "local-huggingface"

# Read the training and validation Dataiku datasets as pandas DataFrames and
# wrap each one as a Hugging Face dataset.
train_dataset = Dataset.from_pandas(dataiku.Dataset("sft_train").get_dataframe())
eval_dataset = Dataset.from_pandas(dataiku.Dataset("sft_validation").get_dataframe())

# Look for a secret named "hf_token" in the auth info and, if one exists, use
# the first match to log in to the Hugging Face Hub. When there is no such
# secret, no login is attempted.
auth_info = dataiku.api_client().get_auth_info(with_secrets=True)
hf_secret = next(
    (entry for entry in auth_info["secrets"] if entry["key"] == "hf_token"),
    None,
)
if hf_secret is not None:
    huggingface_hub.login(token=hf_secret["value"])

# First output of the recipe; the fine-tuned LLM version is created on it.
saved_model = recipe.get_outputs()[0]

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Tokenizer of the pinned base model; the end-of-sequence token is reused as
# the padding token.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, revision=MODEL_REVISION)
tokenizer.pad_token = tokenizer.eos_token

# Quantization settings: 4-bit loading, "nf4" quantization type, float16
# compute dtype.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

# Load the base model at the same revision as the tokenizer, quantized with
# the settings above and with automatic device placement.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    revision=MODEL_REVISION,
)
# Cached results are switched off because the model changes while it is
# being fine-tuned.
model.config.use_cache = False

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
def formatting_func(example):
    """
    Build one instruction-formatted string per row of a batch.

    ``example`` maps column names to per-row sequences. The 'prompt' and
    'output' columns are paired by position, and each pair is rendered as
    "[INST] <prompt> [/INST] <output>". Returns the list of rendered strings,
    in row order.
    """
    return [
        f"[INST] {question} [/INST] {example['output'][row]}"
        for row, question in enumerate(example['prompt'])
    ]

# Token ids of the "[/INST]" marker that separates the prompt from the answer
# in the strings produced by formatting_func. They are given to the collator
# as the response template.
response_template = tokenizer.encode("[/INST]", add_special_tokens=False)
# Completion-only collator (per the TRL documentation, it restricts the loss
# to the tokens that follow the response template).
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

with saved_model.create_finetuned_llm_version(connection_name) as finetuned_llm_version:
    # Define the training parameters; checkpoints and the final model are
    # written to the working directory of the fine-tuned LLM version.
    train_conf = SFTConfig(
        output_dir=finetuned_llm_version.working_directory,
        save_safetensors=True,
        gradient_checkpointing=True,
        num_train_epochs=10,
        per_device_train_batch_size=16,
        optim="paged_adamw_8bit",
        logging_steps=50,
        eval_strategy="steps",
        neftune_noise_alpha=5
    )

    # LoRA adapter settings for causal language modelling.
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        task_type="CAUSAL_LM"
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        formatting_func=formatting_func,
        args=train_conf,
        peft_config=peft_config,
        # The collator was previously built but never handed to the trainer,
        # so it had no effect on training; pass it so it is actually used.
        data_collator=collator,
    )

    # Fine-tune the model
    trainer.train()

    # Save the fine-tuned model in the managed folder
    trainer.save_model()

    # Record the effective batch size and the training log on the new
    # fine-tuned LLM version's configuration.
    config = finetuned_llm_version.config
    config["batchSize"] = trainer.state.train_batch_size
    config["eventLog"] = trainer.state.log_history