/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml;

import com.dataiku.dip.analysis.ml.DKUMLUtils;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLFlowUtils;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.ml.shared.EvaluationLabelsHelper;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.SplitParams;
import com.dataiku.dip.analysis.model.core.ModelUserMeta;
import com.dataiku.dip.analysis.model.core.PreTrainModelingParams;
import com.dataiku.dip.analysis.model.core.ResolvedCoreParams;
import com.dataiku.dip.analysis.model.core.ResolvedPreprocessingParams;
import com.dataiku.dip.analysis.model.core.SavedModelOriginInfo;
import com.dataiku.dip.containers.exec.ContainerExecRuntimeConfig;
import com.dataiku.dip.containers.exec.ContainerExecSelection;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.JobAuthCtxService;
import com.dataiku.dip.dataflow.RecipeRunnableSubgraph;
import com.dataiku.dip.dataflow.exec.FinalCommitable;
import com.dataiku.dip.dataflow.exec.RecipeRunnerWithPayload;
import com.dataiku.dip.dataflow.graph.FlowRecipe;
import com.dataiku.dip.dataflow.jobrunner.JobContext;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.recipes.InitializableAbortableRecipeRunner;
import com.dataiku.dip.rpc.TicketBasedIntercomAPIClient;
import com.dataiku.dip.security.tickets.APITicketService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.LogsService;
import com.dataiku.dip.shaker.model.ScriptStep;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.org.apache.commons.io.FileUtils;
import java.io.File;
import java.io.IOException;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Common base for training recipe runners.
 *
 * <p>Holds the job activity and the flow recipe it runs, and provides shared helpers to:
 * write the parameter files of a model folder, record the origin of a saved model version,
 * delegate abort notifications to a nested runner, and perform the final commit (optional
 * activation of the new version, increment of the last-train index, copy of the activity log).
 *
 * <p>{@link #runFolder}, {@link #sm} and {@link #newVersionId} are never assigned in this class;
 * they have to be populated elsewhere (presumably by subclasses) before the helpers that read
 * them are called.
 */
public abstract class AbstractTrainingRecipeRunner
        implements InitializableAbortableRecipeRunner, RecipeRunnerWithPayload, FinalCommitable {

    private static final DKULogger logger;

    @Autowired
    protected JobAuthCtxService authCtxService;
    @Autowired
    protected APITicketService ticketService;
    @Autowired
    protected DatasetsDAO datasetsDAO;
    @Autowired
    private SavedModelsDAO savedModelsDAO;

    /** Runner most recently launched through {@link #startRunner}; {@code null} until then. */
    private InitializableAbortableRecipeRunner abortableRunner;

    protected final JobActivity activity;
    protected final FlowRecipe recipe;

    /** Folder that receives {@code train.log} in {@link #copyLogFile()}; not assigned in this class. */
    protected File runFolder;
    /** Saved model read by the commit and metadata helpers; not assigned in this class. */
    protected SavedModel sm;
    /** Default value returned by {@link #getVersionToActivate()}; not assigned in this class. */
    protected String newVersionId;

    /**
     * Stores the activity, extracts the recipe from its subgraph and initializes the activity status.
     *
     * @param activity job activity to run; its subgraph is cast to {@link RecipeRunnableSubgraph}
     */
    public AbstractTrainingRecipeRunner(JobActivity activity) {
        this.activity = activity;
        this.recipe = ((RecipeRunnableSubgraph) activity.getSubgraph()).getRecipe();
        activity.initStatus();
    }

    /**
     * Builds the user metadata (name and labels) of a model version.
     *
     * <p>The name is {@code legacyName} when {@code modelVersionNamePrefix} is blank, otherwise
     * the prefix followed by the saved model's next version suffix.
     *
     * @param splitParams            forwarded to the labels helper
     * @param mti                    forwarded to the labels helper
     * @param modelVersionNamePrefix optional name prefix; may be blank
     * @param legacyName             name used when the prefix is blank
     * @param algorithmName          forwarded to the labels helper
     * @param defaultInputDataset    forwarded to the labels helper
     * @return the populated metadata object
     * @throws IOException declared by this method; presumably raised by the labels helper
     */
    protected ModelUserMeta createUserMeta(SplitParams splitParams, ModelTrainInfo mti, String modelVersionNamePrefix, String legacyName, String algorithmName, String defaultInputDataset) throws IOException {
        ModelUserMeta userMeta = new ModelUserMeta();
        if (StringUtils.isBlank(modelVersionNamePrefix)) {
            userMeta.name = legacyName;
        } else {
            userMeta.name = modelVersionNamePrefix + this.sm.getNextVersionSuffix();
        }
        // The name is set first because the metadata object itself is handed to the labels helper.
        userMeta.labels = EvaluationLabelsHelper.getTrainTimeLabels_T(
                this.recipe.getModel(), splitParams, this.sm, defaultInputDataset, algorithmName, mti, userMeta);
        return userMeta;
    }

    /**
     * Finalizes the training: activates the new version when required, increments the saved
     * model's last-train index, then copies the activity log next to the model.
     *
     * <p>Activation happens when the publish policy is {@code UNCONDITIONAL} or when
     * {@link MLFlowUtils#hasValidActiveVersion} reports no valid active version.
     *
     * @throws Exception if obtaining the ticket, an intercom call or saving the model fails
     */
    @Override
    public void finalCommit() throws Exception {
        final String ticketSecret = this.ticketService.getSingleTicket().getSecret();
        try (TicketBasedIntercomAPIClient client = TicketBasedIntercomAPIClient.forLocalHost(ticketSecret)) {
            final boolean mustActivate =
                    this.sm.publishPolicy == SavedModel.ModelPublishPolicy.UNCONDITIONAL
                            || !MLFlowUtils.hasValidActiveVersion(this.sm);
            if (mustActivate) {
                logger.info((Object) "Setting new version as active scoring version");
                final String versionToActivate = this.getVersionToActivate();
                client.postFormToJSON(
                        "/dip/api/tintercom/savedmodels/set-active",
                        Void.class,
                        new Object[] {
                            "projectKey", this.recipe.getProjectKey(),
                            "smId", this.sm.id,
                            "versionId", versionToActivate
                        });
                // Mirror the activation on the local saved model object and persist it.
                this.sm.activeVersion = versionToActivate;
                this.savedModelsDAO.save(this.sm);
            }
            this.incrementSavedModelTrainIndex(client);
        }
        this.copyLogFile();
    }

    /**
     * Returns the id of the version that {@link #finalCommit()} activates.
     *
     * @return {@link #newVersionId} in this base implementation
     */
    protected String getVersionToActivate() {
        return this.newVersionId;
    }

    /**
     * Calls the intercom endpoint that increments the last-train index of the saved model,
     * passing the project key, the saved model id and the current job id.
     *
     * @param tClient intercom client used to post the request
     * @throws IOException if the request fails
     */
    protected void incrementSavedModelTrainIndex(TicketBasedIntercomAPIClient tClient) throws IOException {
        logger.info((Object) "Increment last train index");
        tClient.postFormToJSON(
                "/dip/api/tintercom/savedmodels/increment-last-train",
                Void.class,
                new Object[] {
                    "projectKey", this.recipe.getProjectKey(),
                    "smId", this.sm.id,
                    "jobId", JobContext.getCurrentJobContext().jobId
                });
        logger.info((Object) "Done Increment last train index");
    }

    /**
     * Writes the origin descriptor of a saved model version into the model folder.
     *
     * <p>The origin is always {@code TRAINED_FROM_RECIPE} and the job id is taken from the
     * current job context.
     *
     * @param outModelFolder    model folder; the target file is resolved by {@link FullModelId#getSmOriginFile}
     * @param generatingModelId stored as the origin's full model id
     * @param jobIdUpdate       stored as the origin's {@code jobIdUpdate}
     * @throws IOException if the file cannot be written
     */
    protected void saveSavedModelOriginInfo(File outModelFolder, String generatingModelId, int jobIdUpdate) throws IOException {
        SavedModelOriginInfo originInfo = new SavedModelOriginInfo();
        originInfo.origin = SavedModelOriginInfo.Origin.TRAINED_FROM_RECIPE;
        originInfo.fullModelId = generatingModelId;
        originInfo.jobId = JobContext.getCurrentJobContext().jobId;
        originInfo.jobIdUpdate = jobIdUpdate;
        JSON.prettyToFile((Object) originInfo, (File) FullModelId.getSmOriginFile(outModelFolder));
    }

    /**
     * Same as {@link #saveSavedModelOriginInfo(File, String, int)} with a {@code jobIdUpdate} of 0.
     *
     * @param outModelFolder    model folder receiving the origin file
     * @param generatingModelId stored as the origin's full model id
     * @throws IOException if the file cannot be written
     */
    protected void saveSavedModelOriginInfo(File outModelFolder, String generatingModelId) throws IOException {
        this.saveSavedModelOriginInfo(outModelFolder, generatingModelId, 0);
    }

    /** @return the preprocessing parameters written to {@code rpreprocessing_params.json} */
    protected abstract ResolvedPreprocessingParams getPreprocessing();

    /** @return the modeling parameters written to {@code rmodeling_params.json} */
    protected abstract PreTrainModelingParams getModeling();

    /**
     * @param var1 container selection, as passed to {@link #prepareModelFolder}
     * @return the core parameters written to {@code core_params.json}
     */
    protected abstract ResolvedCoreParams resolveCoreParams(ContainerExecSelection var1);

    /**
     * Serializes the split description and the preprocessing, modeling and core parameters as
     * JSON files inside the model folder, then logs the modeling parameters.
     *
     * <p>{@code containerConfig} and {@code splitFolder} are not used by this implementation.
     *
     * @param outModelFolder     folder receiving the JSON files
     * @param containerSelection passed to {@link #resolveCoreParams}
     * @param containerConfig    unused here
     * @param splitFolder        unused here
     * @param splitDesc          written to {@code split/split.json}
     * @throws IOException if one of the files cannot be written
     */
    protected void prepareModelFolder(File outModelFolder, ContainerExecSelection containerSelection, ContainerExecRuntimeConfig containerConfig, File splitFolder, SplitDesc splitDesc) throws IOException {
        final File splitFile = new File(outModelFolder, "split/split.json");
        JSON.prettyToFile((Object) splitDesc, splitFile);

        final File preprocessingFile = new File(outModelFolder, "rpreprocessing_params.json");
        JSON.prettyToFile((Object) this.getPreprocessing(), preprocessingFile);

        final File modelingFile = new File(outModelFolder, "rmodeling_params.json");
        JSON.prettyToFile((Object) this.getModeling(), modelingFile);

        final File coreParamsFile = new File(outModelFolder, "core_params.json");
        JSON.prettyToFile((Object) this.resolveCoreParams(containerSelection), coreParamsFile);

        logger.info((Object) ("Read modeling: \n " + JSON.prettyLog((Object) this.getModeling())));
    }

    /**
     * Copies the log file of the current activity to {@code train.log} in {@link #runFolder}.
     *
     * <p>Best effort: security and I/O failures are logged as warnings and not rethrown.
     */
    protected void copyLogFile() {
        try {
            final JobContext jobContext = JobContext.getCurrentJobContext();
            final File sourceLog = LogsService.getActivityLogFile(jobContext.projectKey, jobContext.jobId, this.activity.id());
            final File targetLog = new File(this.runFolder, "train.log");
            FileUtils.copyFile(sourceLog, targetLog);
        } catch (DKUSecurityException | IOException copyFailure) {
            logger.warn((Object) "Could not copy the training log file to the saved model.", copyFailure);
        }
    }

    /** Forwards the abort notification to the nested runner, if one has been started. */
    @Override
    public void notifyBeforeAborting() {
        if (this.abortableRunner == null) {
            return;
        }
        this.abortableRunner.notifyBeforeAborting();
    }

    /**
     * Autowires, initializes and runs a nested runner.
     *
     * <p>The runner is registered as the abort target only after {@code init()} has returned
     * and before {@code run()} is invoked.
     *
     * @param runner runner to execute
     * @throws Exception whatever the runner's {@code init()} or {@code run()} throws
     */
    protected void startRunner(InitializableAbortableRecipeRunner runner) throws Exception {
        SpringUtils.getInstance().autowire((Object) runner);
        runner.init();
        this.abortableRunner = runner;
        runner.run();
    }

    static {
        // Class loading hooks run before the logger is created; the order is kept as is.
        DKUMLUtils.loadClasses();
        ScriptStep.loadClass();
        logger = DKULogger.getLogger("dku.recipes.train");
    }
}

