# See examples on https://developer.dataiku.com/latest/concepts-and-examples/llm-mesh.html#fine-tuning

## Variables

# Required settings: both values below must be filled in before the recipe can run.
# They are checked with an explicit raise rather than `assert`, because assert
# statements are removed when Python runs with -O and the empty value would
# then slip through unnoticed.

# Name of the base LLM to fine-tune.
base_model_name = ""
if not base_model_name:
    raise ValueError("please specify a base LLM, it must be available on HuggingFace hub")

# Connection through which the fine-tuned LLM will be exposed.
connection_name = ""
if not connection_name:
    raise ValueError("please specify a connection name, the fine-tuned LLM will be available from this connection")

# these columns must be in the input dataset
user_message_column = "user_message"
assistant_message_column = "assistant_message"
# `columns` is the list of columns read from the input dataset(s) and handed,
# in this order, to the prompt formatter.
columns = [user_message_column, assistant_message_column]

system_message_column = ""  # optional
static_system_message = ""  # optional
# The system message column is only read when one has been configured.
if system_message_column:
    columns.append(system_message_column)

## Code

import datasets
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

from dataiku import recipe
from dataiku.llm.finetuning import formatters

# turn Dataiku datasets into SFTTrainer datasets
# Resolve the recipe inputs once instead of on every access.
recipe_inputs = recipe.get_inputs()

# The first input is the training dataset. The intermediate pandas DataFrame is
# passed straight to `from_pandas` so that no module-level name keeps it bound
# for the rest of the script (i.e. during training).
training_dataset = recipe_inputs[0]
train_dataset = datasets.Dataset.from_pandas(training_dataset.get_dataframe(columns=columns))

# A second input, when present, is used as the validation dataset; it is read
# with the same columns as the training dataset.
validation_dataset = None
eval_dataset = None
if len(recipe_inputs) > 1:
    validation_dataset = recipe_inputs[1]
    eval_dataset = datasets.Dataset.from_pandas(validation_dataset.get_dataframe(columns=columns))

# load the base model and tokenizer
model = AutoModelForCausalLM.from_pretrained(base_model_name)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# Fall back to the EOS token for padding only when the tokenizer does not
# define a pad token of its own. Overwriting an existing pad token would make
# padding indistinguishable from end-of-sequence.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Builds the function that turns a dataset row into a chat-formatted prompt,
# using the tokenizer's chat template and the configured columns (in order).
# NOTE(review): `static_system_message` is not passed to the formatter here;
# in this script it is only recorded in the fine-tuned version config.
# Confirm whether ConversationalPromptFormatter should receive it.
formatting_func = formatters.ConversationalPromptFormatter(tokenizer.apply_chat_template, *columns)

# fine-tune using SFTTrainer
# The first recipe output is the saved model that receives the new fine-tuned version.
saved_model = recipe.get_outputs()[0]
with saved_model.create_finetuned_llm_version(connection_name) as finetuned_llm_version:
    # Customize here. Requirement: put a transformers model in safetensors format into finetuned_llm_version.working_directory.

    # Training arguments: one epoch, a log entry at every step, sequences
    # capped at 1024 tokens (or less if the tokenizer's limit is lower).
    # Step-wise evaluation is only enabled when `eval_dataset` is truthy.
    sft_config = SFTConfig(
        output_dir=finetuned_llm_version.working_directory,
        save_safetensors=True,
        num_train_epochs=1,
        logging_steps=1,
        max_seq_length=min(tokenizer.model_max_length, 1024),
        eval_strategy="steps" if eval_dataset else "no",
    )

    # LoRA adapter settings, applied to all linear layers of the causal LM.
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        task_type="CAUSAL_LM",
        target_modules="all-linear",
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        formatting_func=formatting_func,
        args=sft_config,
        peft_config=lora_config,
    )
    trainer.train()
    # Writes the trained model into the trainer's output directory, i.e. the
    # working directory of the fine-tuned version.
    trainer.save_model()

    # Record how this version was produced: datasets, column mapping, and the
    # trainer's batch size and log history.
    version_metadata = {"trainingDataset": training_dataset.short_name}
    if validation_dataset:
        version_metadata["validationDataset"] = validation_dataset.short_name
    version_metadata["userMessageColumn"] = user_message_column
    version_metadata["assistantMessageColumn"] = assistant_message_column
    version_metadata["systemMessageColumn"] = system_message_column
    version_metadata["staticSystemMessage"] = static_system_message
    version_metadata["batchSize"] = trainer.state.train_batch_size
    version_metadata["eventLog"] = trainer.state.log_history

    config = finetuned_llm_version.config
    for key, value in version_metadata.items():
        config[key] = value
