/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.flow;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLDiagnostics;
import com.dataiku.dip.analysis.ml.MLPaths;
import com.dataiku.dip.analysis.ml.ModelVersioning;
import com.dataiku.dip.analysis.ml.prediction.PartitionedExtractService;
import com.dataiku.dip.analysis.ml.prediction.PredictionResultsReader;
import com.dataiku.dip.analysis.ml.prediction.StratifiedMetricsAggregator;
import com.dataiku.dip.analysis.ml.prediction.flow.AbstractPredictionTrainingRecipeRunner;
import com.dataiku.dip.analysis.ml.prediction.flow.PredictionSMMgmtService;
import com.dataiku.dip.analysis.ml.prediction.flow.PythonPredictionTrainRecipeSubrunner;
import com.dataiku.dip.analysis.ml.prediction.flow.TabularPredictionTrainingRecipePayloadParams;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.ModelDetailsBase;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.SplitParams;
import com.dataiku.dip.analysis.model.core.ModelUserMeta;
import com.dataiku.dip.analysis.model.core.PreTrainModelingParams;
import com.dataiku.dip.analysis.model.core.ResolvedPreprocessingParams;
import com.dataiku.dip.analysis.model.prediction.BinaryClassificationModelPerf;
import com.dataiku.dip.analysis.model.prediction.PartitionedModelExtract;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.PredictionModelSnippetData;
import com.dataiku.dip.containers.exec.ContainerExecConfigSelector;
import com.dataiku.dip.containers.exec.ContainerExecRuntimeConfig;
import com.dataiku.dip.containers.exec.ContainerExecSelection;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.RecipeRunnableSubgraph;
import com.dataiku.dip.dataflow.exec.ContainerRecipeParams;
import com.dataiku.dip.dataflow.graph.FlowDataset;
import com.dataiku.dip.dataflow.graph.FlowSavedModel;
import com.dataiku.dip.dataflow.jobrunner.JobContext;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.partitioning.Partition;
import com.dataiku.dip.partitioning.PartitioningUtils;
import com.dataiku.dip.partitioning.StratifiedModelUtils;
import com.dataiku.dip.recipes.InitializableAbortableRecipeRunner;
import com.dataiku.dip.remoterun.RemoteRunsRegistry;
import com.dataiku.dip.rpc.TicketBasedIntercomAPIClient;
import com.dataiku.dip.security.impersonation.FilesystemACLUtils;
import com.dataiku.dip.shaker.model.SerializedShakerScript;
import com.dataiku.dip.sql.SchemaReader;
import com.dataiku.dip.utils.AutoCloseableLock;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.NamedLock;
import com.dataiku.dip.utils.NotImplementedException;
import com.google.gson.reflect.TypeToken;
import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Base runner for the training recipe of tabular prediction saved models.
 *
 * <p>{@link #run()} creates a new saved-model version folder, dumps the input dataset schema
 * into it, then trains either a single model (unpartitioned saved model) or the activity's
 * target partition (partitioned saved model). In the partitioned case each activity trains its
 * partition into its own part-version folder. Under a named lock derived from the job id and
 * the saved model, it then records its results in the job's {@link PartitionedModelExtract}.
 * Finally it recomputes the aggregated split description, train info and overall metrics of
 * its version folder.
 *
 * <p>Subclasses supply the payload ({@link #getDesc()}), the Python command, the remote
 * execution type and the runner used for unpartitioned training.
 */
public abstract class TabularPredictionTrainingRecipeRunner
extends AbstractPredictionTrainingRecipeRunner {
    @Autowired
    private PartitionedExtractService partitionedExtractService;
    protected static final DKULogger logger = DKULogger.getLogger((String)"dku.recipes.prediction.train");

    protected TabularPredictionTrainingRecipeRunner(JobActivity activity) {
        super(activity);
    }

    /**
     * Validates that the payload's backend and prediction type are supported by this runner.
     * Implementations presumably throw when they are not (TODO confirm in subclasses).
     */
    protected abstract void checkBackendAndPredictionType();

    /**
     * Hook invoked at the very start of {@link #run()}. By default it only calls
     * {@link #checkBackendAndPredictionType()}.
     *
     * @param desc the training payload (unused by the default implementation)
     */
    protected void performChecksBeforeRun(TabularPredictionTrainingRecipePayloadParams desc) throws Exception {
        this.checkBackendAndPredictionType();
    }

    /** @return the command passed to the Python training subrunner */
    protected abstract String getCommand();

    /** @return the training payload; this class reads and, in places, mutates it */
    protected abstract TabularPredictionTrainingRecipePayloadParams getDesc();

    /** @return the remote execution type passed to the Python training subrunner */
    protected abstract RemoteRunsRegistry.ExecutionType getRemoteExecutionType();

    /**
     * Creates the runner used to train an unpartitioned model.
     *
     * @param var1 id of the saved-model version being trained
     * @param var2 folder of that version
     * @param var3 resolved container runtime configuration
     * @param var4 container selection from the recipe params
     */
    protected abstract InitializableAbortableRecipeRunner createUnpartitionedRunner(FullModelId var1, File var2, ContainerExecRuntimeConfig var3, ContainerExecSelection var4) throws Exception;

    /**
     * Creates the Python subrunner that trains one partition, with automatic permission
     * handling disabled. Presumably this is because ACLs for partitioned training are granted
     * explicitly by {@link #grantRelevantPermissionForPartitionedTraining} (TODO confirm).
     */
    protected PythonPredictionTrainRecipeSubrunner createPartitionedRunner(FullModelId fmi, File partOutModelFolder, ContainerExecRuntimeConfig containerConfig) throws Exception {
        PythonPredictionTrainRecipeSubrunner runner = this.createPythonPredictionTrainingSubrunner(fmi, partOutModelFolder, containerConfig);
        runner.disableAutomaticPermissionHandling();
        return runner;
    }

    /**
     * Builds the id of the version created by this activity: the current epoch-millisecond
     * timestamp. For partitioned saved models it is suffixed with {@code "_"} and the first 8
     * hex chars of the MD5 of the target partition id, so concurrent partition activities get
     * distinct ids.
     */
    private String getNewVersionId(FlowSavedModel fsm) {
        String versionId = String.valueOf(System.currentTimeMillis());
        if (fsm.getSavedModel().isPartitioned()) {
            Partition partition = this.activity.getSubgraph().getTargetPartition(fsm);
            versionId = versionId + "_" + DigestUtils.md5Hex((String)partition.id()).substring(0, 8);
        }
        return versionId;
    }

    /**
     * Creates a Python training subrunner writing to {@code outModelFolder}. Side effect: marks
     * the current activity summary's engine type as {@code "DSS"}.
     */
    protected PythonPredictionTrainRecipeSubrunner createPythonPredictionTrainingSubrunner(FullModelId fmi, File outModelFolder, ContainerExecRuntimeConfig containerConfig) {
        TabularPredictionTrainingRecipePayloadParams desc = this.getDesc();
        RemoteRunsRegistry.ExecutionType executionType = this.getRemoteExecutionType();
        JobContext.getCurrentActivitySummary().engineType = "DSS";
        return new PythonPredictionTrainRecipeSubrunner(this.activity, fmi, outModelFolder, containerConfig, desc, executionType, this.getCommand(), outModelFolder.getAbsolutePath(), desc.operationMode.toString());
    }

    /**
     * Asks the local backend (ticket-authenticated intercom call) for the versions of
     * {@code sm} and returns the one chosen by
     * {@link PredictionSMMgmtService#getLatestSMVersionId}. The caller treats a {@code null}
     * result as "no version available".
     */
    private String getLatestVersionFromBackend(SavedModel sm, String projectKey) throws IOException {
        List versions;
        String secret = this.ticketService.getSingleTicket().getSecret();
        try (TicketBasedIntercomAPIClient tClient = TicketBasedIntercomAPIClient.forLocalHost(secret);){
            versions = (List)tClient.postFormToJSON("/dip/api/tintercom/savedmodels/list-versions", TypeToken.getParameterized(ArrayList.class, (Type[])new Type[]{PredictionSMMgmtService.PredictionSMVersionHeader.class}), new Object[]{"projectKey", projectKey, "savedModelId", sm.id});
        }
        return PredictionSMMgmtService.getLatestSMVersionId(versions);
    }

    /**
     * Recomputes the row counts of {@code globalSplitDesc}. The test/train/full counts are first
     * reset to 0, then each partition's {@code split/split.json} is merged in via
     * {@link StratifiedModelUtils#mergeSplitDesc}.
     */
    private void combinePartsSplitDesc(SplitDesc globalSplitDesc, Set<FullModelId> partitionsFmis) throws IOException {
        globalSplitDesc.testRows = 0L;
        globalSplitDesc.trainRows = 0L;
        globalSplitDesc.fullRows = 0L;
        for (FullModelId partFmi : partitionsFmis) {
            SplitDesc partSplitDesc = partFmi.parseSessionFile("split/split.json", SplitDesc.class);
            StratifiedModelUtils.mergeSplitDesc(globalSplitDesc, partSplitDesc);
        }
    }

    /**
     * For partitioned saved models, returns the version recorded for the current job by the
     * partitioned extract service. Otherwise returns the version created by this activity.
     */
    @Override
    protected String getVersionToActivate() {
        if (this.sm.isPartitioned()) {
            return this.partitionedExtractService.get((String)JobContext.getCurrentJobContext().jobId, (SavedModel)this.sm).jobIdVersion;
        }
        return this.newVersionId;
    }

    @Override
    protected ResolvedPreprocessingParams getPreprocessing() {
        return this.getDesc().getPreprocessing();
    }

    @Override
    protected PreTrainModelingParams getModeling() {
        return this.getDesc().getPreTrainModelingParams();
    }

    /**
     * Writes the optional {@code rassertions.json} (Python-based backends only) and
     * {@code roverrides.json} files into {@code modelFolder}, then delegates to the parent
     * implementation.
     */
    @Override
    protected void prepareModelFolder(File modelFolder, ContainerExecSelection containerSelection, ContainerExecRuntimeConfig containerConfig, File splitFolder, SplitDesc splitDesc) throws IOException {
        TabularPredictionTrainingRecipePayloadParams desc = this.getDesc();
        if (desc.backendType.isPythonBased() && desc.getAssertionsParams() != null) {
            JSON.prettyToFile((Object)desc.getAssertionsParams(), (File)new File(modelFolder, "rassertions.json"));
        }
        if (desc.getOverridesParams() != null && desc.getOverridesParams().hasOverrides()) {
            JSON.prettyToFile((Object)desc.getOverridesParams(), (File)new File(modelFolder, "roverrides.json"));
        }
        super.prepareModelFolder(modelFolder, containerSelection, containerConfig, splitFolder, splitDesc);
    }

    /**
     * Entry point of the recipe.
     *
     * <p>Runs the pre-run checks and resolves the target saved model. It then creates the new
     * version folder (restricting the saved-model folder permissions if needed) and stores the
     * main input dataset schema in it. Next it trains either the target partition or the whole
     * model. Finally it records the new version id on the activity's target status.
     */
    @Override
    public void run() throws Exception {
        this.performChecksBeforeRun(this.getDesc());
        RecipeRunnableSubgraph subgraph = (RecipeRunnableSubgraph)this.activity.getSubgraph();
        FlowSavedModel fsm = (FlowSavedModel)subgraph.getTargets().get(0);
        this.sm = fsm.getSavedModel();
        this.newVersionId = this.getNewVersionId(fsm);
        File outModelFolder = MLPaths.savedModelVersionFolder(this.sm, this.newVersionId);
        MLPaths.createIfNeededSavedModelFolderAndRestrictPermissions(this.sm);
        DKUFileUtils.mkdirs((File)outModelFolder);
        FullModelId newFMI = new FullModelId(this.sm.projectKey, this.sm.id, this.newVersionId);
        List<FlowDataset> inputFDSs = subgraph.getSourceDatasetsForRole("main");
        Dataset inputDataset = inputFDSs.get(0).getMandatory(this.datasetsDAO);
        JSON.prettyToFile((Object)inputDataset.getSchema(), (File)new File(outModelFolder, "input_dataset_schema.json"));
        if (this.sm.isPartitioned()) {
            this.runPartitionedTraining(outModelFolder);
        } else {
            this.runUnpartitionedTraining(newFMI, outModelFolder);
        }
        this.activity.getTargetStatus((String)fsm.getFullId()).modelVersionId = this.newVersionId;
    }

    /**
     * Builds the model user metadata from {@code train_info.json} in {@code outModelFolder}
     * (after updating it with the post-search modeling description) and writes it to
     * {@code user_meta.json}. For binary classification the active classifier threshold is
     * taken from {@code perf.json}'s {@code usedThreshold}.
     */
    private void fillAndSaveUserMeta(File outModelFolder, SplitParams splitParams, RecipeRunnableSubgraph subgraph) throws Exception {
        TabularPredictionTrainingRecipePayloadParams desc = this.getDesc();
        ModelTrainInfo mti = (ModelTrainInfo)JSON.parseFile((File)new File(outModelFolder, "train_info.json"), ModelTrainInfo.class);
        PostTrainPredictionModelingParams resolved = (PostTrainPredictionModelingParams)this.updateTrainInfoWithPostSearchDescription(outModelFolder, mti, desc.getPreTrainModelingParams());
        List<FlowDataset> inputFDSs = subgraph.getSourceDatasetsForRole("main");
        String defaultInputDataset = inputFDSs.get(0).getMandatory(this.datasetsDAO).getFullName();
        ModelUserMeta mum = this.createUserMeta(splitParams, mti, desc.modelVersionNamePrefix, desc.getPreTrainModelingParams().generateName(), resolved.algorithm.name(), defaultInputDataset);
        if (desc.getCoreParams().prediction_type == PredictionMLTask.PredictionType.BINARY_CLASSIFICATION) {
            BinaryClassificationModelPerf perf = (BinaryClassificationModelPerf)JSON.parseFile((File)new File(outModelFolder, "perf.json"), BinaryClassificationModelPerf.class);
            mum.activeClassifierThreshold = perf.usedThreshold;
        }
        JSON.prettyToFile((Object)mum, (File)new File(outModelFolder, "user_meta.json"));
    }

    /**
     * Trains a non-partitioned model into {@code outModelFolder}: prepares the splits and the
     * model folder, runs the subclass-provided runner, then writes diagnostics warnings, user
     * meta, version info, pipeline meta and origin info.
     *
     * @param fmi id of the version being trained (saved model's project key, saved model id,
     *     {@link #newVersionId})
     * @param outModelFolder folder of that version
     */
    private void runUnpartitionedTraining(FullModelId fmi, File outModelFolder) throws Exception {
        TabularPredictionTrainingRecipePayloadParams desc = this.getDesc();
        this.runFolder = outModelFolder;
        RecipeRunnableSubgraph subgraph = (RecipeRunnableSubgraph)this.activity.getSubgraph();
        File splitFolder = new File(outModelFolder, "split");
        DKUFileUtils.mkdirs((File)splitFolder);
        if (desc.script == null) {
            desc.script = new SerializedShakerScript();
        }
        String projectKey = this.recipe.getProjectKey();
        ContainerExecSelection containerSelection = this.recipe.getModel().getParamsAs(ContainerRecipeParams.class).getContainerSelection();
        ContainerExecRuntimeConfig containerConfig = new ContainerExecConfigSelector().selectForML_autoTXN(this.authCtxService.getAuthCtx(), projectKey, containerSelection, desc.backendType);
        SplitDesc splitDesc = this.predictionSplitService.prepareSplits(desc, splitFolder, false, this.newVersionId, subgraph);
        this.prepareModelFolder(outModelFolder, containerSelection, containerConfig, splitFolder, splitDesc);
        InitializableAbortableRecipeRunner runner = this.createUnpartitionedRunner(fmi, outModelFolder, containerConfig, containerSelection);
        this.startRunner(runner);
        // Use the id of the version actually trained (keyed on the saved model's project), the
        // same one used for the model folder and the pipeline meta below.
        MLDiagnostics.mergeIntoWarnings(fmi, this.activity.warnContext);
        this.fillAndSaveUserMeta(outModelFolder, splitDesc.params, subgraph);
        ModelVersioning.dumpTrainVersionInfo(desc.backendType, outModelFolder);
        PredictionSMMgmtService.fillPipelineMeta(fmi);
        this.saveSavedModelOriginInfo(outModelFolder, desc.generatingModelId);
    }

    /**
     * Trains the activity's target partition of a partitioned saved model.
     *
     * <p>The partition is trained into its own part-version folder. Its JSON files are copied
     * into {@code outModelFolder} without replacing existing ones. Then, under the job's named
     * lock, the partition is recorded in the job's partitioned extract, and the aggregated
     * split, train info and overall metrics of {@code outModelFolder} are recomputed from all
     * partitions known to the extract.
     *
     * @param outModelFolder folder of the global version created by this activity
     */
    private void runPartitionedTraining(File outModelFolder) throws Exception {
        TabularPredictionTrainingRecipePayloadParams desc = this.getDesc();
        RecipeRunnableSubgraph subgraph = (RecipeRunnableSubgraph)this.activity.getSubgraph();
        FlowSavedModel fsm = (FlowSavedModel)subgraph.getTargets().get(0);
        long startTime = System.currentTimeMillis();
        String currentJobId = JobContext.getCurrentJobContext().jobId;
        // Validate the payload before creating folders, taking the job lock, registering the
        // job in the partitioned extract service or preparing splits. An unsupported
        // configuration then fails fast without leaving half-initialized state behind.
        if (desc.backendType == MLTask.BackendType.DEEP_HUB) {
            throw new IllegalArgumentException("Only tabular prediction supports partitions, got: " + desc.getClass().getSimpleName());
        }
        if (desc.backendType != MLTask.BackendType.PY_MEMORY) {
            throw new NotImplementedException("Unsupported backend type for partitioned model training: " + String.valueOf((Object)desc.backendType));
        }
        TabularPredictionTrainingRecipePayloadParams.PartitionedTrainingSource partSource = desc.partSource;
        if (partSource == null) {
            throw new IllegalArgumentException("Partitioned source cannot be null.");
        }
        String partSourceVersionId = desc.partSourceVersionId;
        Partition partition = this.activity.getSubgraph().getTargetPartition(fsm);
        String partitionName = partition.id();
        String encodedPartitionName = PartitioningUtils.encode(partitionName);
        String partVersionId = "" + System.currentTimeMillis();
        FullModelId partitionFmi = new FullModelId(this.sm.projectKey, this.sm.id, this.newVersionId, partitionName, partVersionId);
        File partOutModelFolder = MLPaths.savedModelPartVersionFolder(this.sm, encodedPartitionName, partVersionId);
        DKUFileUtils.mkdirs((File)partOutModelFolder);
        this.runFolder = partOutModelFolder;
        FullModelId globalModelId = new FullModelId(this.sm.projectKey, this.sm.id, this.newVersionId);
        String lockName = PartitionedExtractService.getJobLockName(currentJobId, this.sm);
        // Version of the saved model whose extract seeds this job's partitioned extract; stays
        // null when training from scratch.
        FullModelId sourceModelFmi = null;
        try (AutoCloseableLock lock = NamedLock.acquire((String)lockName);){
            switch (partSource) {
                case ACTIVE_VERSION: {
                    if (StringUtils.isNotBlank((String)this.sm.activeVersion) && MLPaths.savedModelVersionFolder(this.sm, this.sm.activeVersion).exists()) {
                        sourceModelFmi = new FullModelId(this.sm.projectKey, this.sm.id, this.sm.activeVersion);
                    } else {
                        logger.warnV("No active model version found for Saved Model '%s', training model from scratch", new Object[]{this.sm.id});
                        desc.partSource = TabularPredictionTrainingRecipePayloadParams.PartitionedTrainingSource.NONE;
                    }
                    break;
                }
                case LATEST_VERSION: {
                    String latestModelVersionId = this.getLatestVersionFromBackend(this.sm, this.recipe.getProjectKey());
                    if (latestModelVersionId == null) {
                        logger.warnV("No model version found for Saved Model '%s', training model from scratch", new Object[]{this.sm.id});
                        desc.partSource = TabularPredictionTrainingRecipePayloadParams.PartitionedTrainingSource.NONE;
                        break;
                    }
                    sourceModelFmi = new FullModelId(this.sm.projectKey, this.sm.id, latestModelVersionId);
                    break;
                }
                case EXPLICIT_VERSION: {
                    if (StringUtils.isBlank((String)partSourceVersionId)) {
                        throw new IllegalArgumentException("No explicit partitioned model source given");
                    }
                    File explicitModelFolder = MLPaths.savedModelVersionFolder(this.sm, partSourceVersionId);
                    if (!explicitModelFolder.exists()) {
                        throw new IllegalArgumentException("Partitioned model source version '" + partSourceVersionId + "' does not exist");
                    }
                    sourceModelFmi = new FullModelId(this.sm.projectKey, this.sm.id, partSourceVersionId);
                    break;
                }
                case NONE: {
                    break;
                }
                default: {
                    throw new NotImplementedException("Unsupported partitioned source type: " + partSource.name());
                }
            }
            // createIfNeeded reports whether this activity is the first of the job for this
            // saved model; only that one grants read ACLs on the global model folder.
            boolean firstActivity = this.partitionedExtractService.createIfNeeded(currentJobId, this.sm, sourceModelFmi);
            this.grantRelevantPermissionForPartitionedTraining(firstActivity, globalModelId, partOutModelFolder);
        }
        File splitFolder = new File(partOutModelFolder, "split");
        DKUFileUtils.mkdirs((File)splitFolder);
        if (desc.script == null) {
            desc.script = new SerializedShakerScript();
        }
        ContainerExecSelection containerSelection = this.recipe.getModel().getParamsAs(ContainerRecipeParams.class).getContainerSelection();
        ContainerExecRuntimeConfig containerConfig = new ContainerExecConfigSelector().selectForML_autoTXN(this.authCtxService.getAuthCtx(), this.recipe.getProjectKey(), containerSelection, desc.backendType);
        SplitDesc splitDesc = this.predictionSplitService.prepareSplits(desc, splitFolder, !partition.isNP(), this.newVersionId, subgraph);
        if (partSource.hasSource() && sourceModelFmi != null) {
            // Best-effort check: a schema mismatch with the source model is only logged.
            File oldSplitFile = DKUFileUtils.getWithin((File)sourceModelFmi.getModelFolder(), (String[])new String[]{"split", "split.json"});
            logger.infoV("Source split %s", new Object[]{oldSplitFile.getAbsolutePath()});
            SplitDesc oldSplitDesc = (SplitDesc)JSON.parseFile((File)oldSplitFile, SplitDesc.class);
            try {
                SchemaReader.isSchemaCompatible(oldSplitDesc.schema, splitDesc.schema, "Saved model", "Deployed model", false);
            }
            catch (IllegalArgumentException ex) {
                logger.warn((Object)"Original saved model schema and newly generated one do not match.", (Throwable)ex);
            }
        }
        this.prepareModelFolder(partOutModelFolder, containerSelection, containerConfig, splitFolder, splitDesc);
        PreTrainPredictionModelingParams modeling = (PreTrainPredictionModelingParams)desc.getPreTrainModelingParams();
        logger.info((Object)("Read modeling: \n " + JSON.prettyLog((Object)modeling)));
        PythonPredictionTrainRecipeSubrunner runner = this.createPartitionedRunner(partitionFmi, partOutModelFolder, containerConfig);
        this.startRunner(runner);
        this.fillAndSaveUserMeta(partOutModelFolder, splitDesc.params, subgraph);
        ModelVersioning.dumpTrainVersionInfo(desc.backendType, partOutModelFolder);
        PredictionSMMgmtService.fillPipelineMeta(partitionFmi);
        // Seed the global version folder with the partition's JSON files, keeping any file
        // already present there.
        DKUFileUtils.copyDirectory((File)partOutModelFolder, (File)outModelFolder, EnumSet.of(DKUFileUtils.CopyDirectoryFlags.NoReplacing), (FileFilter[])new FileFilter[]{f -> f.getName().endsWith(".json")});
        // Partition activities of the same job share the extract; serialize the read-modify-write.
        try (AutoCloseableLock lock = NamedLock.acquire((String)lockName);){
            PartitionedExtractService.TrainedFromJobModelInfo trainedFromJobModelInfo = this.partitionedExtractService.get(currentJobId, this.sm);
            PartitionedModelExtract extract = trainedFromJobModelInfo.extract;
            // jobIdVersion is set by partitionedExtractService.set(...) at the end of this
            // section. Null here presumably means no activity of this job has completed yet,
            // so states inherited from the source model are marked as reused (TODO confirm).
            if (trainedFromJobModelInfo.jobIdVersion == null) {
                extract.setStatesToReused();
            }
            extract.versions.put(partitionName, partVersionId);
            ModelTrainInfo partMti = (ModelTrainInfo)JSON.parseFile((File)new File(partOutModelFolder, "train_info.json"), ModelTrainInfo.class);
            extract.setState(partitionName, partMti.state);
            try {
                PredictionModelDetails details = PredictionResultsReader.makeModelDetails(partitionFmi);
                PredictionModelSnippetData snippet = PredictionResultsReader.makeSnippet(details);
                PredictionResultsReader.addPartitionedModelInfo(snippet, partitionFmi);
                if (details.mlDiagnostics != null) {
                    details.mlDiagnostics.mergeIntoWarnings(this.activity.warnContext);
                }
                extract.summaries.put(partitionName, new PartitionedModelExtract.PartitionedModelSummary(snippet));
            }
            catch (IOException ex) {
                // Best-effort: the partition stays registered even without a summary.
                logger.warn((Object)"Failed to retrieve partitioned model snippet", (Throwable)ex);
            }
            Set<FullModelId> partitionsFmis = globalModelId.getSMPartitionsFmis(extract.versions);
            this.combinePartsSplitDesc(splitDesc, partitionsFmis);
            JSON.prettyToFile((Object)splitDesc, (File)new File(outModelFolder, "split/split.json"));
            List<ModelDetailsBase> modelDetailsList = StratifiedMetricsAggregator.retrievePerPartitionMetrics(globalModelId, partitionsFmis);
            ModelTrainInfo mti = globalModelId.parseModelFile("train_info.json", ModelTrainInfo.class);
            StratifiedMetricsAggregator.setTrainInfo(mti, startTime, System.currentTimeMillis());
            JSON.prettyToFile((Object)mti, (File)globalModelId.getModelFile("train_info.json"));
            StratifiedMetricsAggregator.computeAndSaveOverallMetrics(desc.getCoreParams().prediction_type, modelDetailsList, globalModelId);
            int newJobIdUpdate = trainedFromJobModelInfo.jobIdUpdate + 1;
            globalModelId.savePartFile(extract);
            this.saveSavedModelOriginInfo(outModelFolder, desc.generatingModelId, newJobIdUpdate);
            this.partitionedExtractService.set(currentJobId, this.sm, extract, this.newVersionId, newJobIdUpdate);
        }
        this.fillAndSaveUserMeta(outModelFolder, splitDesc.params, subgraph);
        ModelVersioning.dumpTrainVersionInfo(desc.backendType, outModelFolder);
        PredictionSMMgmtService.fillPipelineMeta(globalModelId);
    }

    /**
     * Grants filesystem ACLs for partitioned training. Full ACLs are always granted on the
     * partition's folder. Read ACLs on the saved model's version folder are granted only when
     * {@code initialPartition} is true.
     */
    private void grantRelevantPermissionForPartitionedTraining(boolean initialPartition, FullModelId globalModelId, File partOutModelFolder) throws IOException, DKUSecurityException, InterruptedException {
        if (initialPartition) {
            logger.info((Object)"Initial partition, granting read ACLs to Saved Model");
            FilesystemACLUtils.grantFSReadACLs(this.authCtxService.getAuthCtx(), this.sm.projectKey, globalModelId.getFolderEnsuringSecurity());
        }
        FilesystemACLUtils.grantFSFullACLs(this.authCtxService.getAuthCtx(), this.sm.projectKey, partOutModelFolder);
    }
}

