/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.split;

import com.dataiku.dip.analysis.ml.prediction.flow.AbstractPredictionTrainingRecipePayloadParams;
import com.dataiku.dip.analysis.ml.prediction.split.ForcedSplitGenerator;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.ml.prediction.split.SplitUtils;
import com.dataiku.dip.analysis.model.SplitParams;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.ProcessorOutput;
import com.dataiku.dip.datalayer.ProcessorOutputToSIP;
import com.dataiku.dip.datalayer.RowFactory;
import com.dataiku.dip.datalayer.streamimpl.StreamColumnFactory;
import com.dataiku.dip.datalayer.streamimpl.StreamRowFactory;
import com.dataiku.dip.datasets.UniversalSingleThreadPusher;
import com.dataiku.dip.input.utils.CountingProcessorOutput;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.shaker.model.SerializedShakerScript;
import com.dataiku.dip.shaker.server.ShakerStreamService;
import com.dataiku.dip.utils.JSON;
import java.io.File;
import java.util.EnumSet;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * {@link ForcedSplitGenerator} that pushes a single dataset through a preparation script and
 * writes the prepared rows to files resolved from a target folder.
 *
 * <p>Depending on {@code params.streamAll} and the recipe operation mode, {@link #compute()}
 * writes a randomly split train/test pair, a "full" file containing every pushed row, or both.
 */
public class RSDForcedSplitGenerator
implements ForcedSplitGenerator {
    /** Operation modes for which the train/test pair is written (unless {@code streamAll} is set). */
    private static final EnumSet<AbstractPredictionTrainingRecipePayloadParams.OperationMode> TRAIN_TEST_MODES =
            EnumSet.of(
                    AbstractPredictionTrainingRecipePayloadParams.OperationMode.TRAIN_SPLITTED_AND_FULL,
                    AbstractPredictionTrainingRecipePayloadParams.OperationMode.TRAIN_SPLITTED_ONLY);

    /** Operation modes for which the full set is written (it is always written when {@code streamAll} is set). */
    private static final EnumSet<AbstractPredictionTrainingRecipePayloadParams.OperationMode> FULL_SET_MODES =
            EnumSet.of(
                    AbstractPredictionTrainingRecipePayloadParams.OperationMode.TRAIN_FULL_ONLY,
                    AbstractPredictionTrainingRecipePayloadParams.OperationMode.TRAIN_SPLITTED_AND_FULL,
                    AbstractPredictionTrainingRecipePayloadParams.OperationMode.TRAIN_KFOLD);

    /** Injected by the explicit {@code SpringUtils.autowire} call made in the constructor. */
    @Autowired
    private ShakerStreamService shakerStreamService;
    private final AuthCtx authCtx;
    private final Dataset dataset;
    private final SplitParams params;
    private final File targetFolder;
    private final SerializedShakerScript script;
    private final Schema preparationOutputSchema;
    private final AbstractPredictionTrainingRecipePayloadParams.OperationMode operationMode;

    /**
     * Creates a generator and autowires its {@code @Autowired} fields through {@link SpringUtils}.
     *
     * <p>NOTE: the train/test policy precondition is expressed with a plain {@code assert}, so it
     * is only checked when the JVM runs with assertions enabled.
     *
     * @param authCtx authentication context handed to the shaker service and the dataset pusher
     * @param dataset dataset whose rows are pushed
     * @param params split parameters; expected to use {@code SPLIT_SINGLE_DATASET}
     * @param script preparation script applied to each pushed row
     * @param preparationOutputSchema schema recorded in the descriptor and given to the file writers
     * @param targetFolder folder from which the output file locations are resolved
     * @param operationMode recipe operation mode selecting which outputs are produced
     */
    public RSDForcedSplitGenerator(AuthCtx authCtx, Dataset dataset, SplitParams params, SerializedShakerScript script, Schema preparationOutputSchema, File targetFolder, AbstractPredictionTrainingRecipePayloadParams.OperationMode operationMode) {
        SpringUtils.getInstance().autowire((Object)this);
        assert (params.ttPolicy == SplitParams.TrainTestPolicy.SPLIT_SINGLE_DATASET);
        this.authCtx = authCtx;
        this.preparationOutputSchema = preparationOutputSchema;
        this.dataset = dataset;
        this.params = params;
        this.script = script;
        this.targetFolder = targetFolder;
        this.operationMode = operationMode;
    }

    /**
     * Writes the requested output files and returns a descriptor of what was written.
     *
     * <p>The descriptor carries a deep copy of the split parameters (with
     * {@code ssdDatasetSmartName} resolved against the script's context project key), the
     * preparation output schema, and the file name and row count of each file produced.
     *
     * @return the populated split descriptor
     * @throws Exception if writing or pushing the data fails
     */
    @Override
    public SplitDesc compute() throws Exception {
        SplitDesc newDesc = new SplitDesc();
        newDesc.format = "csv1";
        newDesc.generationDate = System.currentTimeMillis();
        newDesc.params = (SplitParams)JSON.deepCopy((Object)this.params);
        newDesc.params.ssdDatasetSmartName = this.dataset.getSmartName(this.script.contextProjectKey);
        newDesc.schema = this.preparationOutputSchema;

        // streamAll suppresses the split export and forces the full export.
        if (!this.params.streamAll && TRAIN_TEST_MODES.contains((Object)this.operationMode)) {
            this.writeTrainTestSplit(newDesc);
        }
        if (this.params.streamAll || FULL_SET_MODES.contains((Object)this.operationMode)) {
            this.writeFullSet(newDesc);
        }
        return newDesc;
    }

    /** Pushes the dataset through a random splitter into the train and test files and records both in {@code desc}. */
    private void writeTrainTestSplit(SplitDesc desc) throws Exception {
        // The train and test pipelines share the same factories, as does the pusher feeding them.
        ColumnFactory cf = new StreamColumnFactory();
        RowFactory rf = new StreamRowFactory();

        File trainPath = SplitUtils.getSavedModelTrainSetFile(this.targetFolder);
        CountingProcessorOutput trainWriter = SplitUtils.getWriterToSingleFile(trainPath, desc.schema, cf);
        ProcessorOutputToSIP trainPipeline = this.openPipeline(trainWriter, cf, rf);

        File testPath = SplitUtils.getSavedModelTestSetFile(this.targetFolder);
        CountingProcessorOutput testWriter = SplitUtils.getWriterToSingleFile(testPath, desc.schema, cf);
        ProcessorOutputToSIP testPipeline = this.openPipeline(testWriter, cf, rf);

        SplitUtils.RandomSplitter splitter = new SplitUtils.RandomSplitter((ProcessorOutput)trainPipeline, (ProcessorOutput)testPipeline, this.params.ssdSeed, this.params.ssdTrainingRatio);
        UniversalSingleThreadPusher pusher = new UniversalSingleThreadPusher(this.authCtx, this.dataset, splitter, cf, rf);
        pusher.setDatasetSelection(this.params.ssdSelection);
        pusher.push();
        // End-of-stream is signalled on the splitter itself, not on the two pipelines.
        splitter.lastRowEmitted();

        // Only the file names are stored in the descriptor, not the full paths.
        desc.trainPath = trainPath.getName();
        desc.testPath = testPath.getName();
        desc.trainRows = trainWriter.getCount();
        desc.testRows = testWriter.getCount();
    }

    /** Pushes the dataset into the single full-set file and records it in {@code desc}. */
    private void writeFullSet(SplitDesc desc) throws Exception {
        ColumnFactory cf = new StreamColumnFactory();
        RowFactory rf = new StreamRowFactory();

        File fullPath = SplitUtils.getSavedModelFullSetFile(this.targetFolder);
        CountingProcessorOutput fullWriter = SplitUtils.getWriterToSingleFile(fullPath, desc.schema, cf);
        ProcessorOutputToSIP fullPipeline = this.openPipeline(fullWriter, cf, rf);

        UniversalSingleThreadPusher pusher = new UniversalSingleThreadPusher(this.authCtx, this.dataset, (ProcessorOutput)fullPipeline, cf, rf);
        pusher.setDatasetSelection(this.params.ssdSelection);
        pusher.push();
        fullPipeline.lastRowEmitted();

        // Only the file name is stored in the descriptor, not the full path.
        desc.fullPath = fullPath.getName();
        desc.fullRows = fullWriter.getCount();
    }

    /** Asks the shaker stream service for a processor output that runs the script and forwards to {@code sink}. */
    private ProcessorOutputToSIP openPipeline(CountingProcessorOutput sink, ColumnFactory cf, RowFactory rf) throws Exception {
        return this.shakerStreamService.getProcessorOutput(this.authCtx, this.dataset.getProjectKey(), this.script, (ProcessorOutput)sink, cf, rf);
    }
}

