/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.recipes.nlp.finetuning;

import com.dataiku.dip.containers.exec.ContainerExecSelection;
import com.dataiku.dip.recipes.ParamsWithContainerizable;
import com.dataiku.dip.recipes.nlp.common.NLPRecipePayloadParams;

/**
 * Payload parameters for the fine-tuning recipe.
 *
 * <p>Holds the input column selection, the system-message configuration, the container
 * selection and the fine-tuning hyperparameters. All fields are public and mutable. Callers or
 * (de)serializers assign them directly.
 */
public class FineTuningRecipePayloadParams extends NLPRecipePayloadParams
        implements ParamsWithContainerizable {
    /** Name of the input column presumably holding the prompts — confirm against the recipe runner. */
    public String promptColumn;

    /** Name of the input column presumably holding the expected completions — confirm. */
    public String completionColumn;

    /**
     * How the system message is provided. It is {@code null} unless set by the caller.
     * NOTE(review): presumably STATIC uses {@link #systemMessage} and DYNAMIC uses
     * {@link #systemMessageColumn} — not enforced in this class; confirm with the consumer.
     */
    public SystemMessageMode systemMessageMode;

    /** Fixed system message text; see {@link #systemMessageMode}. */
    public String systemMessage;

    /** Name of the column providing a per-row system message; see {@link #systemMessageMode}. */
    public String systemMessageColumn;

    /** Container execution selection; initialized to a fresh {@link ContainerExecSelection}. */
    public ContainerExecSelection containerSelection = new ContainerExecSelection();

    /** Fine-tuning hyperparameters; initialized to an instance whose values are mostly unset. */
    public FineTuningHyperparameters hyperparameters = new FineTuningHyperparameters();

    /** Whether to deploy the fine-tuned model; defaults to {@code true}. */
    public boolean deployFinetunedModel = true;

    /**
     * Whether to clean inactive deployments; defaults to {@code true}.
     * NOTE(review): "SMV" presumably means saved model version — confirm.
     */
    public boolean cleanInactiveSMVDeployments = true;

    /** Returns the current container selection (may be {@code null} if a caller set it so). */
    @Override
    public ContainerExecSelection getContainerSelection() {
        return this.containerSelection;
    }

    /**
     * Replaces the container selection.
     *
     * @param selection the new selection; stored as-is, without validation
     */
    @Override
    public void setContainerSelection(ContainerExecSelection selection) {
        this.containerSelection = selection;
    }

    /**
     * Hyperparameters for fine-tuning, split between a local Hugging Face section and a remote
     * section.
     */
    public static class FineTuningHyperparameters {
        /**
         * Defaults to {@code true}. This flag is not read inside this class.
         * NOTE(review): presumably tells the consumer whether to call
         * {@link #setDefaultHyperparameters()} instead of using user values — confirm.
         */
        public boolean useDefaults = true;

        /** Number of training epochs; {@link #setDefaultHyperparameters()} sets it to 3. */
        public Integer nbEpochs;

        /** Hyperparameters for the local Hugging Face path. */
        public LocalHuggingFaceHyperparameters localHuggingFace = new LocalHuggingFaceHyperparameters();

        /** Hyperparameters for the remote path. */
        public RemoteHyperparameters remoteHyperparameters = new RemoteHyperparameters();

        /**
         * Overwrites every hyperparameter with its built-in default, including those of the
         * nested local and remote sections.
         *
         * <p>The nested sections are public, non-final fields and may have been set to
         * {@code null} by a caller or a deserializer. A missing section is recreated before its
         * defaults are applied, instead of failing with a {@link NullPointerException}.
         */
        public void setDefaultHyperparameters() {
            this.nbEpochs = 3;
            if (this.localHuggingFace == null) {
                this.localHuggingFace = new LocalHuggingFaceHyperparameters();
            }
            this.localHuggingFace.setDefaultLocalHFHyperparameters();
            if (this.remoteHyperparameters == null) {
                this.remoteHyperparameters = new RemoteHyperparameters();
            }
            this.remoteHyperparameters.setDefaultRemoteHyperparameters();
        }

        /**
         * Hyperparameters for the local Hugging Face path. All values are {@code null} until
         * assigned or until {@link #setDefaultLocalHFHyperparameters()} is called.
         * NOTE(review): names suggest LoRA adapter settings (rank, alpha, dropout, rsLoRA) plus
         * NEFTune noise — inferred from names only; confirm.
         */
        public static class LocalHuggingFaceHyperparameters {
            /** Presumably the LoRA rank; default 256. */
            public Integer r;
            /** Presumably the LoRA alpha; default 32. */
            public Integer loraAlpha;
            /** Presumably the LoRA dropout; default 0.05. */
            public Double loraDropout;
            /** Presumably toggles rank-stabilized LoRA; default {@code true}. */
            public Boolean useRsLora;
            /** Presumably the NEFTune noise alpha; default 5.0. */
            public Double neftuneNoiseAlpha;
            /** Initial learning rate; default 5.0E-5. */
            public Double initialLearningRate;
            /** Quantization mode; default {@link QuantizationMode#Q_4BIT}. */
            public QuantizationMode quantization;
            /** Checkpoint retention mode; default {@link CheckpointMode#KEEP_BEST_ONLY}. */
            public CheckpointMode checkpointMode;

            /** Overwrites every field of this section with its built-in default. */
            public void setDefaultLocalHFHyperparameters() {
                this.quantization = QuantizationMode.Q_4BIT;
                this.checkpointMode = CheckpointMode.KEEP_BEST_ONLY;
                this.r = 256;
                this.loraAlpha = 32;
                this.loraDropout = 0.05;
                this.useRsLora = true;
                this.neftuneNoiseAlpha = 5.0;
                this.initialLearningRate = 5.0E-5;
            }

            /** Model quantization options: none, 4-bit or 8-bit (per constant names). */
            public enum QuantizationMode {
                NONE,
                Q_4BIT,
                Q_8BIT
            }

            /** Which checkpoints to keep; exact retention semantics are defined by the consumer. */
            public enum CheckpointMode {
                KEEP_ALL_EVAL,
                KEEP_BEST_ONLY,
                KEEP_BEST_AND_LAST_ONLY
            }
        }

        /**
         * Hyperparameters for the remote path. All values are {@code null} until assigned or
         * until {@link #setDefaultRemoteHyperparameters()} is called.
         */
        public static class RemoteHyperparameters {
            /** Learning-rate multiplier; default 2.0. */
            public Float learningRateMultiplier;
            /** Batch size; default 4. */
            public Integer batchSize;

            /** Overwrites every field of this section with its built-in default. */
            public void setDefaultRemoteHyperparameters() {
                this.learningRateMultiplier = 2.0f;
                this.batchSize = 4;
            }
        }
    }

    /** How the system message is supplied: not at all, as fixed text, or per row (by name). */
    public enum SystemMessageMode {
        NONE,
        STATIC,
        DYNAMIC
    }
}

