import json
import logging
import os
import sys
from dataclasses import dataclass
from typing import List, Tuple, Optional, Any, Dict

from trl import apply_chat_template, SFTConfig

# Root logging is configured here, ahead of the torch/dataiku imports below
# (presumably so that anything they log at import time already uses this format).
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s')
# Module-level logger used by the functions of this recipe.
logger = logging.getLogger("local_hf_fine_tuning_recipe")

import torch

from dataiku.huggingface.torch_utils import is_bfloat16_supported_with_cuda
from dataiku.llm.finetuning.metrics import read_training_metrics
from dataiku import Dataset as DkuDataset
from dataiku.base.folder_context import build_folder_context, FolderContext
from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.core import dkujson
from dataiku.core.model_provider import get_model_from_cache

# Fallback hyperparameter values. They are the field defaults of `UserDefinedHyperparameters`
# and are also used by `UserDefinedHyperparameters.from_recipe_desc` when the recipe
# description does not provide a value.
DEFAULT_VALUES = {
    "nb_epochs": 3.0,
    "lora_rank": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "neftune_noise_alpha": 5,
    "initial_learning_rate": 5e-5,
    "quantization_mode": "Q_4BIT",
    "checkpoint_mode": "KEEP_BEST_ONLY"
}

# Used by `load_and_train` to derive the number of steps between two evaluations.
IDEAL_NUMBER_OF_EVALUATIONS_BY_EPOCHS = 3 # Number of evaluations we want to perform during a single epoch


def get_target_modules(model):
    """
    Collect the names of the modules LoRA should target.

    This replaces the `all-linear` shortcut, as other module kinds are currently not supported:
    https://github.com/huggingface/peft/blob/52684952136b3dcc1120834f8af2d8f5aaa1c16a/src/peft/tuners/lora/model.py#L346

    Args:
        model (transformers.PreTrainedModel): The loaded pre-trained model.

    Returns:
        List[str]: De-duplicated names of the modules we want to target (order is not significant).
    """
    from transformers import Conv1D

    supported_layer_types = (torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, Conv1D)
    target_names = set()

    # named_modules() walks the model and all of its nested submodules
    for qualified_name, submodule in model.named_modules():
        if not isinstance(submodule, supported_layer_types):
            continue
        # Qualified names look like:
        #   - model.layers.1.self_attn.o_proj
        #   - model.layers.1.mlp.gate_proj
        # so the fifth dotted component is the projection name.
        # NOTE(review): for names with more than five components, index 4 is an intermediate
        # component rather than the leaf module - confirm this is intended.
        components = qualified_name.split(".")
        if len(components) < 5:
            continue
        target_names.add(components[4])
    return list(target_names)


@dataclass
class UserDefinedHyperparameters:
    """
    Hyperparameters of the fine-tuning job that can be set by the end-user in the recipe.

    Every field defaults to the matching entry of `DEFAULT_VALUES`.
    """
    # Number of training epochs
    nb_epochs: float = DEFAULT_VALUES["nb_epochs"]
    # LoRA adapter settings (rank, alpha scaling and dropout)
    lora_rank: int = DEFAULT_VALUES["lora_rank"]
    lora_alpha: int = DEFAULT_VALUES["lora_alpha"]
    lora_dropout: float = DEFAULT_VALUES["lora_dropout"]
    # NEFTune noise alpha; 0 is treated as "disabled" when the training arguments are built
    neftune_noise_alpha: float = DEFAULT_VALUES["neftune_noise_alpha"]
    initial_learning_rate: float = DEFAULT_VALUES["initial_learning_rate"]
    # "Q_4BIT", "Q_8BIT"; any other value means no quantization
    quantization_mode: str = DEFAULT_VALUES["quantization_mode"]
    # "KEEP_BEST_ONLY", "KEEP_BEST_AND_LAST_ONLY"; any other value keeps every checkpoint
    checkpoint_mode: str = DEFAULT_VALUES["checkpoint_mode"]

    @staticmethod
    def from_recipe_desc(hyperparameters: Dict[str, Any]) -> 'UserDefinedHyperparameters':
        """
        Build the hyperparameters from the "hyperparameters" section of the recipe description.

        Args:
            hyperparameters (Dict[str, Any]): The "hyperparameters" section of the recipe description.

        Returns:
            UserDefinedHyperparameters: The values to use for training. Any value that is not provided
            falls back to `DEFAULT_VALUES`.

        Raises:
            KeyError: If the "useDefaults" flag is missing.
        """
        if hyperparameters["useDefaults"]:
            # We return default values, except that the default quantization is only kept
            # when CUDA is available.
            quantization_mode = DEFAULT_VALUES["quantization_mode"] if torch.cuda.is_available() else "NONE"
            return UserDefinedHyperparameters(quantization_mode=quantization_mode)

        # All local HuggingFace settings are optional: a missing (or null) section means
        # "use the defaults" for every field, including quantization and checkpoint mode.
        local_hf_settings = hyperparameters.get("localHuggingFace") or {}
        return UserDefinedHyperparameters(
            nb_epochs=hyperparameters.get("nbEpochs", DEFAULT_VALUES["nb_epochs"]),
            lora_rank=local_hf_settings.get("r", DEFAULT_VALUES["lora_rank"]),
            lora_alpha=local_hf_settings.get("loraAlpha", DEFAULT_VALUES["lora_alpha"]),
            lora_dropout=local_hf_settings.get("loraDropout", DEFAULT_VALUES["lora_dropout"]),
            neftune_noise_alpha=local_hf_settings.get("neftuneNoiseAlpha", DEFAULT_VALUES["neftune_noise_alpha"]),
            initial_learning_rate=local_hf_settings.get("initialLearningRate", DEFAULT_VALUES["initial_learning_rate"]),
            quantization_mode=local_hf_settings.get("quantization", DEFAULT_VALUES["quantization_mode"]),
            checkpoint_mode=local_hf_settings.get("checkpointMode", DEFAULT_VALUES["checkpoint_mode"]),
        )


def use_flash_attn(use_gpu: bool = False) -> Tuple[bool, bool]:
    """
    Decide which flash-attention flavours can be used during fine-tuning, based on the installed packages:
      - https://huggingface.co/docs/trl/v0.9.4/en/sft_trainer#using-flash-attention-and-flash-attention-2

    Args:
        use_gpu (bool): Whether fine-tuning runs on GPU. Without a GPU, flash-attention is never used.

    Returns:
        Tuple[bool, bool]: (use flash attention 1, use flash attention 2).
    """
    if not use_gpu:
        logger.info("No GPU device. No use of flash-attention.")
        return False, False

    # flash attention 1 is considered usable when the `optimum` package can be imported
    try:
        import optimum  # noqa: F401 - imported only to probe availability
    except ImportError:
        optimum_available = False
        logger.info("'optimum' package not found. No use of flash attention 1. Install the package to use it.")
    else:
        optimum_available = True
        logger.info("Using flash-attn-1 (`optimum` package)")

    # flash attention 2 is considered usable when the `flash_attn` package can be imported
    try:
        import flash_attn  # noqa: F401 - imported only to probe availability
    except ImportError:
        flash_attn_available = False
        logger.info("'flash-attn' package not found. No use of flash attention 2. Install the package to use it.")
    else:
        flash_attn_available = True
        logger.info("Using flash-attn-2 (`flash-attn` package)")

    return optimum_available, flash_attn_available


def start_training(pre_trained_model,
                   sft_config: SFTConfig,
                   dataset,
                   evaluation_dataset,
                   peft_config,
                   tokenizer,
                   use_flash_attn1: bool,
                   gradient_checkpointing: bool):
    """
    Build the SFT trainer and run the training.

    Note that `peft_config` (target modules) and `sft_config` (gradient checkpointing settings) are
    updated in place before the trainer is created.

    Args:
        pre_trained_model (transformers.PreTrainedModel): The loaded pre-trained model.
        sft_config (SFTConfig): The TRL SFTConfig.
        dataset (datasets.Dataset): The model training dataset.
        evaluation_dataset (datasets.Dataset): The model evaluation dataset, or None to train without evaluation.
        peft_config (LoraConfig): The model LoRa adapter config.
        tokenizer (AutoTokenizer): The model tokenizer, also used to apply the chat template to the datasets.
        use_flash_attn1 (bool): Whether to run the training inside the flash-only SDP kernel context.
        gradient_checkpointing (bool): Whether to use gradient checkpointing or not. The caller first tries
            without it and retries with it when the first attempt fails.

    Returns:
        transformers.Trainer: The trainer, once training is done.
    """
    from trl import SFTTrainer

    def _with_chat_template(raw_dataset):
        # Render each conversation with the tokenizer's chat template
        return raw_dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})

    peft_config.target_modules = get_target_modules(pre_trained_model)
    sft_config.gradient_checkpointing = gradient_checkpointing
    # sc-197921 : needed for torch 2.4+ compatibility
    sft_config.gradient_checkpointing_kwargs = {'use_reentrant': False}

    templated_train_dataset = _with_chat_template(dataset)
    if evaluation_dataset is None:
        templated_eval_dataset = None
    else:
        templated_eval_dataset = _with_chat_template(evaluation_dataset)

    trainer = SFTTrainer(
        pre_trained_model,
        args=sft_config,
        train_dataset=templated_train_dataset,
        eval_dataset=templated_eval_dataset,
        peft_config=peft_config,
        tokenizer=tokenizer,
    )

    if not use_flash_attn1:
        trainer.train()
        return trainer

    # Restrict scaled-dot-product attention to the flash kernel for the duration of the training
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        trainer.train()
    return trainer


def supports_mps(saved_model_input_directory: Optional["FolderContext"] = None, hf_id: Optional[str] = None):
    """
    Utility method to determine whether the base model supports MPS or not. If a saved model is specified,
    we read its adapter config to determine the initial base model id (it takes precedence over `hf_id`).

    The check is a case-insensitive substring match on the model id, so that ids such as
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0" are recognised as well as "meta-llama/...".

    Args:
        saved_model_input_directory (Optional[FolderContext]): The folder path of the fine-tuning base model. This is optional as an HuggingFace model id can be picked instead.
        hf_id (Optional[str]): The HuggingFace id of the fine-tuning base model. This is optional as a previously Dataiku fine-tuned saved model can be picked instead.

    Returns:
        bool: Whether the model supports MPS backend or not. When no model id can be determined, False is
        returned so that the caller falls back to another backend.
    """
    base_model_id = hf_id
    if saved_model_input_directory is not None:
        with saved_model_input_directory.get_file_path_to_read("adapter_config.json") as adapter_config_filename:
            with open(adapter_config_filename, 'r') as f:
                adapter_config = json.load(f)
                base_model_id = adapter_config["base_model_name_or_path"]

    if not base_model_id:
        # Unknown model: we cannot tell, so we conservatively report MPS as unsupported.
        return False

    # modeling_llama.py uses autocast which doesn't support mps
    normalized_id = str(base_model_id).lower()
    return "llama" not in normalized_id and "gemma" not in normalized_id


def build_training_args(output_dir: FolderContext,
                        user_hyperparameters: UserDefinedHyperparameters,
                        evaluation_strategy: str,
                        max_seq_length: int,
                        eval_steps: int,
                        saved_model_input_directory: Optional[FolderContext] = None,
                        hf_id: Optional[str] = None,
                        bnb_enabled: Optional[bool] = True):
    """
    Build the training arguments object. The backend (CUDA, MPS or CPU), the float precision and the optimizer
    are picked according to what is available on the system.

    Args:
        output_dir (FolderContext): Where checkpoints are stored during training.
            NOTE(review): `load_and_train` forwards its `output_directory: str` here, so this looks like
            a path rather than a FolderContext - confirm the annotation.
        user_hyperparameters (UserDefinedHyperparameters): The hyperparameters that are specifically specified by the end-user.
        evaluation_strategy (str): "no" when there is no evaluation dataset, "steps" otherwise.
        max_seq_length (int): The maximum sequence length.
        eval_steps (int): The number of steps between two evaluations (clamped to at least 1). It also drives
            the checkpoint frequency and, divided by 5, the logging frequency.
        saved_model_input_directory (Optional[FolderContext]): The fine-tuning base model folder context. It is useful to read the "true" hf id to determine mps support.
        hf_id (Optional[str]): The HuggingFace id of the fine-tuning base model. This is optional as a previously Dataiku fine-tuned saved model can be picked instead.
        bnb_enabled (Optional[bool]): Whether bitsandbytes is supported and enabled or not

    Returns:
        Tuple[SFTConfig, Union[str, Dict[str, str]], bool, bool]: The training arguments, the device map to be
        used when loading the model ("auto" or {"": "mps"}), whether to use flash-attention 1, and whether to
        use flash-attention 2.
    """
    from trl import SFTConfig

    cuda_available = torch.cuda.is_available()
    device_map = "auto"
    use_cpu = False

    if cuda_available:
        bf16 = is_bfloat16_supported_with_cuda()
        fp16 = not bf16
        # The paged 8-bit optimizer is only picked together with bitsandbytes
        optimizer = "paged_adamw_8bit" if bnb_enabled else "adamw_torch"
        logger.info("Using CUDA backend")
    else:
        bf16 = fp16 = False
        optimizer = "adamw_torch"
        if torch.backends.mps.is_available() and supports_mps(saved_model_input_directory, hf_id):
            # MPS backend
            logger.info("Using MPS backend")
            device_map = {"": "mps"}
        else:
            # CPU fallback
            logger.info("Using CPU backend")
            use_cpu = True

    use_flash_attn1, use_flash_attn2 = use_flash_attn(cuda_available)

    # Maximum number of saved checkpoints (None = unlimited), old checkpoints are progressively deleted
    checkpoint_mode = user_hyperparameters.checkpoint_mode
    if checkpoint_mode == "KEEP_BEST_AND_LAST_ONLY":
        save_total_limit = 2
    elif checkpoint_mode == "KEEP_BEST_ONLY":
        save_total_limit = 1
    else:
        save_total_limit = None

    has_evaluation = evaluation_strategy != "no"
    # A step count can't be 0, and a decimal number would be read as a % of steps, so we use at least 1.
    steps_between_evaluations = max(1, eval_steps)
    # We log the training loss a bit more often than we evaluate
    steps_between_logs = max(1, eval_steps // 5)
    noise_alpha = user_hyperparameters.neftune_noise_alpha
    if noise_alpha == 0:
        noise_alpha = None

    sft_config = SFTConfig(
        output_dir=output_dir,
        fp16=fp16,
        bf16=bf16,
        auto_find_batch_size=True,
        save_safetensors=True,
        eval_strategy=evaluation_strategy,
        eval_steps=steps_between_evaluations,
        logging_steps=steps_between_logs,
        save_total_limit=save_total_limit,
        # Has to be a multiple of eval_steps, we want the maximum number of saves possible
        # (depending on save_total_limit, they may be progressively deleted)
        save_steps=steps_between_evaluations,
        # When set to true alongside save_total_limit=1, we should have at most two saved checkpoints:
        # the best and the last
        load_best_model_at_end=has_evaluation,
        optim=optimizer,
        use_cpu=use_cpu,
        neftune_noise_alpha=noise_alpha,
        max_seq_length=max_seq_length,
        num_train_epochs=user_hyperparameters.nb_epochs,
        learning_rate=user_hyperparameters.initial_learning_rate,
        packing=True,
        eval_packing=has_evaluation
    )

    logger.info(f"Using SFTConfig: {sft_config}")
    logger.info(f"Using device map: {device_map}")

    return sft_config, device_map, use_flash_attn1, use_flash_attn2


def build_tokenizer(pretrained_model_name_or_path: str, token: str):
    """
    Load the tokenizer of the base model and make it pad with its end-of-sequence token.

    Args:
        pretrained_model_name_or_path (str): The path of the fine-tuning base model. This can be a folder context path or an HuggingFace model id.
        token (str): The HuggingFace token from the connection.

    Returns:
        transformers.PreTrainedTokenizer: The tokenizer from the base model.
    """
    from transformers import AutoTokenizer

    logger.info("Loading tokenizer")
    base_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, token=token)

    # Using the EOS token as the padding token is fine here, see:
    # https://stackoverflow.com/questions/76446228/setting-padding-token-as-eos-token-when-using-datacollatorforlanguagemodeling-fr
    base_tokenizer.pad_token = base_tokenizer.eos_token
    return base_tokenizer


def load_dataset(ds_name, desc):
    """
    Load a Dataiku dataset and convert it to a conversational HuggingFace dataset.

    Each row becomes a single-turn conversation stored in a "messages" column: an optional system message,
    followed by the user prompt and the assistant completion. All the initial columns are dropped.

    Args:
        ds_name (str): The Dataiku dataset name.
        desc (Dict[str, Any]): The json description of the fine-tuning recipe. The keys read here are
            "systemMessageMode" ("NONE", "STATIC" or "DYNAMIC"), "promptColumn", "completionColumn",
            "systemMessage" (STATIC mode) and "systemMessageColumn" (DYNAMIC mode). It is not modified.

    Returns:
        datasets.Dataset: The dataset with a single "messages" column.

    Raises:
        Exception: If the prompt or the completion column contains empty values.
    """
    import datasets

    dataset_df = DkuDataset(ds_name).get_dataframe()

    system_message_mode = desc["systemMessageMode"]
    prompt_column = desc["promptColumn"]
    completion_column = desc["completionColumn"]

    if system_message_mode == "DYNAMIC":
        # System message is optional: missing values become empty strings. This has to be written back to
        # the dataframe, otherwise the `astype(str)` below turns them into the literal "nan".
        system_message_column = desc["systemMessageColumn"]
        dataset_df[system_message_column] = dataset_df[system_message_column].fillna("")
    if dataset_df[prompt_column].isnull().values.any():
        raise Exception(f"Prompt column {desc['promptColumn']} has some empty values.")
    if dataset_df[completion_column].isnull().values.any():
        raise Exception(f"Completion column {desc['completionColumn']} has some empty values.")

    dataset_df = dataset_df.astype(str)
    ds = datasets.Dataset.from_pandas(dataset_df)

    initial_columns = ds.column_names
    has_system_prompt = system_message_mode != "NONE"

    def to_conversational(example):
        messages = [
            {"role": "user", "content": example[prompt_column]},
            {"role": "assistant", "content": example[completion_column]}
        ]

        if has_system_prompt:
            if system_message_mode == "STATIC":
                # Same system message for every row
                system_message = desc["systemMessage"]
            else:
                # Per-row system message read from the dataset
                system_message = example[desc["systemMessageColumn"]]
            messages = [{"role": "system", "content": system_message}] + messages

        return {
            "messages": messages
        }

    return ds.map(to_conversational, remove_columns=initial_columns)

def main(
    output_sm_folder_path: str,
    input_training_dataset_name: str,
    recipe_desc: Dict[str, Any],
    evaluation_dataset_name: Optional[str] = None,
    hf_id: Optional[str] = None,
    sm_context_folder_path: Optional[str] = None,
    base_model_in_cache: Optional[str] = None
):
    """
    Fine-tune a LLM base model from HuggingFace. The fine-tuned model is stored as an adapter in a savedmodel folder.

    The training checkpoints are written to the "output" subfolder of the saved model folder. Once training is
    done, the step-wise training metrics ("llm_stepwise_training_metrics.json"), a summary of the training
    setup ("llm_info.json") and the model itself are written to the saved model folder.

    Args:
        output_sm_folder_path (str): The folder path of the fine-tuned saved model.
        input_training_dataset_name (str): The Dataiku training dataset name.
        recipe_desc (Dict[str, Any]): The json description of the fine-tuning recipe.
        evaluation_dataset_name (Optional[str]): The Dataiku evaluation dataset name
        hf_id (Optional[str]): The HuggingFace id of the fine-tuning base model. This is optional as a previously Dataiku fine-tuned saved model can be picked instead.
        sm_context_folder_path (Optional[str]): The folder path of the fine-tuning base model. This is optional as an HuggingFace model id can be picked instead.
        base_model_in_cache (Optional[str]): Name of the base model that is available in DSS model cache
    Returns:
        None
    """
    logger.info("Starting local HF fine-tuning")

    # DSS model cache support
    base_model_path_in_cache = None
    if base_model_in_cache:
        try:
            base_model_path_in_cache = get_model_from_cache(base_model_in_cache)
            os.environ["TRANSFORMERS_OFFLINE"] = "1"
            logger.info("Using DSS model cache %s", base_model_path_in_cache)
        except Exception:
            # Best effort: if the cache can't be used we fall back to `hf_id` below.
            # `Exception` (rather than a bare except) lets KeyboardInterrupt / SystemExit propagate.
            logger.exception("Failed to load the model from DSS model cache")

    if not base_model_path_in_cache:
        from dataiku.huggingface.utils import enable_hf_transfer
        enable_hf_transfer(logger)

    # these imports must be done after we alter environment variables
    # NOTE(review): `trl` is imported at module level, which may already have pulled in transformers
    # before the environment variables above are set - confirm they are still taken into account.
    from peft import LoraConfig
    from transformers import BitsAndBytesConfig

    fine_tuned_model_folder_context = build_folder_context(output_sm_folder_path)

    logger.info("Loading training dataset")
    training_dataset = load_dataset(input_training_dataset_name, recipe_desc)
    logger.info(f"Training dataset loaded ({training_dataset.shape})")

    validation_dataset = None
    if evaluation_dataset_name is not None:
        logger.info("Loading validation dataset")
        validation_dataset = load_dataset(evaluation_dataset_name, recipe_desc)
        logger.info(f"Validation dataset loaded ({validation_dataset.shape})")

    logger.info(f"Fine-tuning HF base model {hf_id if hf_id is not None else sm_context_folder_path}")

    user_hyperparameters = UserDefinedHyperparameters.from_recipe_desc(recipe_desc["hyperparameters"])

    logger.info("Creating quantization config")
    quantization = user_hyperparameters.quantization_mode
    if quantization == "Q_4BIT":
        logger.info("4bit quantization selected")
        # NOTE(review): the compute dtype is always bfloat16 here, whereas build_training_args checks
        # bfloat16 support before enabling bf16 - confirm this is fine on GPUs without bfloat16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif quantization == "Q_8BIT":
        logger.info("8bit quantization selected")
        bnb_config = BitsAndBytesConfig(
            load_in_8bit=True
        )
    else:
        logger.info(f"No quantization selected: value was: {quantization}")
        bnb_config = None

    # Quite a lot of parameters we might want to give access to ?
    peft_config = LoraConfig(
        r=user_hyperparameters.lora_rank,
        lora_alpha=user_hyperparameters.lora_alpha,
        lora_dropout=user_hyperparameters.lora_dropout,
        use_rslora=False,
        bias="none",  # 'none', 'all' or 'lora_only' parameterizable ?
        task_type="CAUSAL_LM",
    )

    hf_token = os.getenv("HF_TOKEN")

    # Checkpoints go to the "output" subfolder of the saved model folder
    with fine_tuned_model_folder_context.get_subfolder_context("output").get_folder_path_to_write() as output_dir:
        if sm_context_folder_path is not None:
            # The base model is a previously fine-tuned saved model
            with build_folder_context(sm_context_folder_path).get_folder_path_to_read() as saved_model_folder:
                trainer = load_and_train(saved_model_folder, hf_token, user_hyperparameters, bnb_config, training_dataset, validation_dataset, peft_config, output_dir, build_folder_context(sm_context_folder_path))
        else:
            # The base model comes from the DSS model cache when available, from HuggingFace otherwise
            trainer = load_and_train(base_model_path_in_cache or hf_id, hf_token, user_hyperparameters, bnb_config, training_dataset, validation_dataset, peft_config, output_dir, None)

    # Then we save
    fine_tuned_model_folder_context.create_if_not_exist()
    with fine_tuned_model_folder_context.get_folder_path_to_write() as output_folder:
        logger.info(f"Saving model to {output_folder}")
        llm_step_wise_training_metrics, nb_epochs, total_steps = read_training_metrics(trainer.state.log_history)
        with fine_tuned_model_folder_context.get_file_path_to_write("llm_stepwise_training_metrics.json") as metrics_filepath:
            with open(metrics_filepath, 'w') as f:
                json.dump(llm_step_wise_training_metrics, f)
        with fine_tuned_model_folder_context.get_file_path_to_write("llm_info.json") as llm_info_filepath:
            with open(llm_info_filepath, 'w') as f:
                json.dump({
                    "batchSize": trainer.state.train_batch_size,
                    "nbEpochs": nb_epochs,
                    "totalSteps": total_steps,
                    "quantizationMode": quantization,
                    "checkpointMode": user_hyperparameters.checkpoint_mode,
                    "loraRank": user_hyperparameters.lora_rank,
                    "loraAlpha": user_hyperparameters.lora_alpha,
                    "loraDropout": user_hyperparameters.lora_dropout,
                    "neftuneNoiseAlpha": user_hyperparameters.neftune_noise_alpha,
                    "initialLearningRate": user_hyperparameters.initial_learning_rate
                }, f)

        # restore original base model name if we fine-tuned from DSS model cache
        if base_model_path_in_cache:
            trainer.model.active_peft_config.base_model_name_or_path = hf_id

        trainer.save_model(output_dir=output_folder)


def load_and_train(pretrained_model_name_or_path: str,
                   token: str,
                   user_hyperparameters: UserDefinedHyperparameters,
                   bnb_config,
                   training_dataset,
                   validation_dataset,
                   peft_config,
                   output_directory: str,
                   saved_model_input_directory: Optional[FolderContext],
                   ):
    """
    Load the model and start the training job.

    Training is first attempted without gradient checkpointing. If that attempt fails, whatever the error,
    it is retried once with gradient checkpointing enabled; a failure of the retry is propagated.

    Args:
        pretrained_model_name_or_path (str): The path of the fine-tuning base model. This can be a folder context path or an HuggingFace model id.
        token (str): The HuggingFace token from the connection.
        user_hyperparameters (UserDefinedHyperparameters): The hyperparameters that are specifically specified by the end-user.
        bnb_config (BitsAndBytesConfig): The description of the quantization of the loaded model, or None when there is no quantization.
        training_dataset (datasets.Dataset): The input training dataset formatted either in instruction format or in ChatML format.
        validation_dataset (datasets.Dataset): The input evaluation dataset formatted either in instruction format or in ChatML format. None disables evaluation.
        peft_config (LoraConfig): The description of the LoRa configuration (the actual adapter to be trained).
        output_directory (str): The directory where checkpoints are stored.
        saved_model_input_directory (Optional[FolderContext]): The fine-tuning base model folder context. It is useful to read the "true" hf id to determine mps support.
    Returns:
        transformers.Trainer: The trained fine-tuned model.
    """
    from transformers import AutoModelForCausalLM

    tokenizer = build_tokenizer(pretrained_model_name_or_path, token)

    # 8 is the default batch size value of transformers, but it is decreased until it works.
    # nb_steps_by_epoch = dataset_size / batch_size
    # We want to get `k` evaluation by epoch we can compute `dataset_size / batch_size / k` which will give us the number of steps between 2 evaluations.
    # In our case the batch size is between [1, 8]. Therefore,
    # => the true number of steps is between [dataset_size / 8, dataset_size].
    # => the true number of steps between 2 evaluations is [dataset_size / 8 / k, dataset_size / k]
    # We always set it to `dataset_size / 8 / k` which means we get between [k, k * 8] evaluations by epochs.
    # Example (dataset_size = 3200, batch_size = 4, k = 3):
    # nb_steps_by_epoch = 3200 / 4 = 800
    # eval_steps = dataset_size / 8 / 3 = 3200 / 8 / 3 = 133
    # We will have an evaluation step every 133 steps. Since we have 800 steps by epochs, that will give us in that case 6 evaluations steps by epoch.
    # For tiny datasets the division gives 0; build_training_args uses at least 1 step, so we clamp here
    # as well to log the value that is really used.
    eval_steps = max(1, training_dataset.num_rows // 8 // IDEAL_NUMBER_OF_EVALUATIONS_BY_EPOCHS)
    logger.info("Will compute evaluation loss every {} steps.".format(eval_steps))

    # build_training_args also tells us which flash-attention flavours are available
    sft_config, device_map, use_flash_attn1, use_flash_attn2 = build_training_args(
        output_directory,
        user_hyperparameters,
        "no" if validation_dataset is None else "steps",
        min(tokenizer.model_max_length, 1024),  # sequences are capped to 1024 tokens
        eval_steps,
        saved_model_input_directory,
        pretrained_model_name_or_path,
        bnb_enabled=bool(bnb_config),
    )

    logger.info("Loading model")

    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        device_map=device_map,
        quantization_config=bnb_config,
        token=token,
        attn_implementation="flash_attention_2" if use_flash_attn2 else None
    )

    logger.info(f"{pretrained_model_name_or_path} loaded. Preparing trainer")
    try:
        logger.info("Starting training without gradient checkpointing.")
        trainer = start_training(model, sft_config, training_dataset, validation_dataset, peft_config, tokenizer, use_flash_attn1, gradient_checkpointing=False)
    except Exception:
        # Any failure triggers the retry (not only CUDA OOM), so we keep the traceback of the first
        # attempt in the logs instead of silently discarding it.
        logger.warning("Trying to recover CUDA OOM. Gradient checkpointing = True", exc_info=True)
        trainer = start_training(model, sft_config, training_dataset, validation_dataset, peft_config, tokenizer, use_flash_attn1, gradient_checkpointing=True)

    return trainer


if __name__ == "__main__":
    read_dku_env_and_set()

    # Exactly seven positional arguments are expected after the script path.
    # Optional ones are passed as empty strings and converted to None below.
    (
        _,
        model_folder_context_path,
        input_dataset_name,
        recipe_desc_filepath,
        validation_dataset_name,
        base_model_hf_id,
        base_model_sm_context_folder_path,
        base_model_in_cache,
    ) = sys.argv

    with ErrorMonitoringWrapper():
        main(
            output_sm_folder_path=model_folder_context_path,
            input_training_dataset_name=input_dataset_name,
            recipe_desc=dkujson.load_from_filepath(recipe_desc_filepath),
            evaluation_dataset_name=validation_dataset_name or None,
            hf_id=base_model_hf_id or None,
            sm_context_folder_path=base_model_sm_context_folder_path or None,
            base_model_in_cache=base_model_in_cache or None,
        )
