/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.flow;

import com.dataiku.dip.analysis.ml.prediction.flow.AbstractPredictionTrainingRecipePayloadParams;
import com.dataiku.dip.analysis.ml.prediction.split.EFDForcedSplitGenerator;
import com.dataiku.dip.analysis.ml.prediction.split.ForcedSplitGenerator;
import com.dataiku.dip.analysis.ml.prediction.split.RSDForcedSplitGenerator;
import com.dataiku.dip.analysis.ml.prediction.split.SSDForcedSplitGenerator;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.model.SplitParams;
import com.dataiku.dip.containers.exec.ContainerExecRuntimeConfig;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.SerializedRecipe;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.dataflow.JobAuthCtxService;
import com.dataiku.dip.dataflow.RecipeRunnableSubgraph;
import com.dataiku.dip.dataflow.graph.FlowComputable;
import com.dataiku.dip.dataflow.graph.FlowDataset;
import com.dataiku.dip.dataflow.graph.FlowRecipe;
import com.dataiku.dip.datasets.DatasetSelection;
import com.dataiku.dip.partitioning.Partition;
import com.dataiku.dip.utils.ErrorContext;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.NotImplementedException;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Prepares the train/test split description ({@link SplitDesc}) used by prediction
 * training recipes, and rewrites the paths of an existing split description for the
 * environment the training runs in.
 */
public class PredictionSplitService {
    /** Prefix put in front of split paths when a container execution config is supplied. */
    private static final String CONTAINER_SPLIT_PREFIX = "split/";

    @Autowired
    private DatasetsDAO datasetsDAO;
    @Autowired
    private JobAuthCtxService authCtxService;

    /**
     * Builds the {@link SplitDesc} for a training recipe run.
     *
     * <p>The split parameters of {@code desc} are deep-copied, and for partitioned input
     * datasets the copy's dataset selections are narrowed to the source partitions reported
     * by {@code subgraph}. What happens next depends on the backend:
     * <ul>
     *   <li>backend is not Spark-based: a {@link ForcedSplitGenerator} matching the train/test
     *       policy is created and the result of its {@code compute()} is returned;</li>
     *   <li>backend is Spark-based: no generator is run; a descriptor with policy id
     *       {@code "none"} carrying the adjusted parameters and the expected preparation
     *       output schema is returned.</li>
     * </ul>
     *
     * @param desc             recipe payload (backend type, split params, script, expected schema)
     * @param splitFolder      folder handed to the split generators
     * @param partitionedModel whether the model is partitioned; only supported with the
     *                         {@code SPLIT_SINGLE_DATASET} policy
     * @param newVersionId     version id, used to build the instance id of a non-forced split
     * @param subgraph         runnable subgraph providing input datasets, partitions and the recipe
     * @return the split description
     * @throws NotImplementedException for a partitioned model with another policy than
     *                                 {@code SPLIT_SINGLE_DATASET}, for an unsupported split mode
     *                                 (forced split only), or for an unknown policy
     * @throws Exception               any failure raised by the underlying lookups or generators
     */
    SplitDesc prepareSplits(AbstractPredictionTrainingRecipePayloadParams desc, File splitFolder, boolean partitionedModel, String newVersionId, RecipeRunnableSubgraph subgraph) throws Exception {
        List<FlowDataset> inputFDSs = subgraph.getSourceDatasetsForRole(new String[]{"main", "test"});
        // A "forced" split (computed by a ForcedSplitGenerator) is only done for non-Spark backends.
        boolean forced = !desc.backendType.isSparkBased();

        if (partitionedModel && desc.splitParams.ttPolicy != SplitParams.TrainTestPolicy.SPLIT_SINGLE_DATASET) {
            throw new NotImplementedException("Partitioned models are not implemented for this split policy " + desc.splitParams.ttPolicy);
        }

        // Work on a copy so that the partition selection below does not alter the recipe payload.
        SplitParams sp = (SplitParams) JSON.deepCopy(desc.splitParams);
        ForcedSplitGenerator fsg = null;

        switch (desc.splitParams.ttPolicy) {
            case SPLIT_SINGLE_DATASET: {
                if (inputFDSs.isEmpty()) {
                    throw ErrorContext.iae("No dataset input in training recipe");
                }
                FlowDataset inputFDS = inputFDSs.get(0);
                Dataset inputDataset = inputFDS.getMandatory(this.datasetsDAO);
                PredictionSplitService.selectPartitions(inputDataset, sp.ssdSelection, inputFDS, subgraph);
                if (forced) {
                    // The split mode is only validated when a generator actually has to be built.
                    if (desc.splitParams.ssdSplitMode == SplitParams.SplitMode.RANDOM || desc.splitParams.streamAll) {
                        fsg = new RSDForcedSplitGenerator(this.authCtxService.getAuthCtx(), inputDataset, sp, desc.script, desc.expectedPreparationOutputSchema, splitFolder, desc.operationMode);
                    } else if (desc.splitParams.ssdSplitMode == SplitParams.SplitMode.SORTED) {
                        fsg = new SSDForcedSplitGenerator(this.authCtxService.getAuthCtx(), inputDataset, sp, desc.script, desc.expectedPreparationOutputSchema, splitFolder, desc.operationMode);
                    } else {
                        throw new NotImplementedException("Invalid split mode: " + desc.splitParams.ssdSplitMode);
                    }
                }
                break;
            }
            case EXPLICIT_FILTERING_SINGLE_DATASET: {
                if (inputFDSs.isEmpty()) {
                    throw ErrorContext.iae("No dataset input in training recipe");
                }
                FlowDataset inputFDS = inputFDSs.get(0);
                Dataset inputDataset = inputFDS.getMandatory(this.datasetsDAO);
                PredictionSplitService.selectPartitions(inputDataset, sp.efsdTrain.selection, inputFDS, subgraph);
                PredictionSplitService.selectPartitions(inputDataset, sp.efsdTest.selection, inputFDS, subgraph);
                if (forced) {
                    // Train and test are both filtered out of the same dataset.
                    fsg = new EFDForcedSplitGenerator(this.authCtxService.getAuthCtx(), inputDataset, inputDataset, sp, desc.script, desc.expectedPreparationOutputSchema, splitFolder);
                }
                break;
            }
            case EXPLICIT_FILTERING_TWO_DATASETS: {
                if (inputFDSs.size() < 2) {
                    throw ErrorContext.iae("Missing dataset input in training recipe");
                }
                FlowRecipe recipe = subgraph.getRecipe();
                // Train and test datasets are resolved through the recipe's "main" and "test" inputs.
                SerializedRecipe.RecipeInput trainRef = recipe.getModel().getSingleInput("main");
                FlowDataset trainFDS = subgraph.getSourceDataset(trainRef.getLoc(recipe.getProjectKey()).getFullName());
                Dataset trainDataset = trainFDS.getMandatory(this.datasetsDAO);
                SerializedRecipe.RecipeInput testRef = recipe.getModel().getSingleInput("test");
                FlowDataset testFDS = subgraph.getSourceDataset(testRef.getLoc(recipe.getProjectKey()).getFullName());
                Dataset testDataset = testFDS.getMandatory(this.datasetsDAO);
                PredictionSplitService.selectPartitions(trainDataset, sp.eftdTrain.selection, trainFDS, subgraph);
                PredictionSplitService.selectPartitions(testDataset, sp.eftdTest.selection, testFDS, subgraph);
                if (forced) {
                    fsg = new EFDForcedSplitGenerator(this.authCtxService.getAuthCtx(), trainDataset, testDataset, sp, desc.script, desc.expectedPreparationOutputSchema, splitFolder);
                }
                break;
            }
            default: {
                throw new NotImplementedException();
            }
        }

        if (forced) {
            // Every forced path through the switch either built a generator or threw.
            assert fsg != null;
            return fsg.compute();
        }

        SplitDesc sd = new SplitDesc();
        sd.policyId = "none";
        sd.instanceId = "unique-" + newVersionId;
        sd.params = sp;
        sd.generationDate = System.currentTimeMillis();
        sd.schema = desc.expectedPreparationOutputSchema;
        return sd;
    }

    /**
     * Returns a deep copy of {@code splitDesc} whose paths are rewritten for the execution
     * environment and whose schema is set to the expected preparation output schema of
     * {@code desc}. The given {@code splitDesc} is not modified.
     *
     * <p>Without a container config, paths are resolved against {@code splitFolder} and made
     * absolute; with one, they are prefixed with {@code "split/"}. As before, the test path is
     * only rewritten together with the train path; a {@code null} test path is now left
     * {@code null} instead of failing or being turned into {@code "split/null"}.
     *
     * @param desc                      recipe payload providing the expected output schema
     * @param splitDesc                 split description to expand
     * @param splitFolder               folder the relative split paths are resolved against
     * @param predictionContainerConfig container execution config, or {@code null} when no
     *                                  container config applies
     * @return the expanded copy
     */
    SplitDesc getExpandedSplit(AbstractPredictionTrainingRecipePayloadParams desc, SplitDesc splitDesc, File splitFolder, ContainerExecRuntimeConfig predictionContainerConfig) {
        SplitDesc expandedSplitDesc = (SplitDesc) JSON.deepCopy(splitDesc);
        boolean inContainer = predictionContainerConfig != null;

        if (expandedSplitDesc.trainPath != null) {
            expandedSplitDesc.trainPath = PredictionSplitService.expandPath(splitFolder, expandedSplitDesc.trainPath, inContainer);
            if (expandedSplitDesc.testPath != null) {
                expandedSplitDesc.testPath = PredictionSplitService.expandPath(splitFolder, expandedSplitDesc.testPath, inContainer);
            }
        }
        if (expandedSplitDesc.fullPath != null) {
            expandedSplitDesc.fullPath = PredictionSplitService.expandPath(splitFolder, expandedSplitDesc.fullPath, inContainer);
        }
        expandedSplitDesc.schema = desc.expectedPreparationOutputSchema;
        return expandedSplitDesc;
    }

    /**
     * Expands one non-null relative split path: {@code "split/"}-prefixed when
     * {@code inContainer} is set, otherwise absolute under {@code splitFolder}.
     */
    private static String expandPath(File splitFolder, String relativePath, boolean inContainer) {
        if (inContainer) {
            return CONTAINER_SPLIT_PREFIX + relativePath;
        }
        return new File(splitFolder, relativePath).getAbsolutePath();
    }

    /**
     * For a partitioned {@code dataset}, switches {@code selection} to explicit partition
     * selection and fills it with the ids of the partitions that {@code subgraph} reports for
     * {@code source}. Leaves {@code selection} untouched for a non-partitioned dataset.
     */
    private static void selectPartitions(Dataset dataset, DatasetSelection selection, FlowComputable source, RecipeRunnableSubgraph subgraph) {
        if (!dataset.getPartitioningSchema().isPartitioned()) {
            return;
        }
        selection.partitionSelectionMethod = DatasetSelection.PartitionSelectionMethod.SELECTED;
        selection.selectedPartitions = new ArrayList<String>();
        for (Partition partition : subgraph.getSourcePartitions(source)) {
            selection.selectedPartitions.add(partition.id());
        }
    }
}

