/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction;

import com.dataiku.common.server.APIError;
import com.dataiku.common.server.SerializedError;
import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.analysis.coreservices.AnalysisMLContainerKernel;
import com.dataiku.dip.analysis.coreservices.AnalysisMLKernel;
import com.dataiku.dip.analysis.coreservices.IAnalysisMLKernel;
import com.dataiku.dip.analysis.coreservices.MLBaseService;
import com.dataiku.dip.analysis.ml.DKUMLUtils;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLTaskLoc;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.ml.shared.ModelStateHelper;
import com.dataiku.dip.analysis.ml.spark.SparkBasedDoctorJob;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.core.AnalysisCoreParams;
import com.dataiku.dip.analysis.model.core.ResolvedCoreParams;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.cluster.SparkSettings;
import com.dataiku.dip.containers.exec.ContainerExecConfigSelector;
import com.dataiku.dip.containers.exec.ContainerExecRuntimeConfig;
import com.dataiku.dip.coremodel.SimpleKeyValue;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.export.ZipUnzipDir;
import com.dataiku.dip.futures.DSSFuturePayloadUtils;
import com.dataiku.dip.futures.FutureAborter;
import com.dataiku.dip.futures.FuturePayload;
import com.dataiku.dip.io.PortRangeParams;
import com.dataiku.dip.io.SingleCommandKernelLink;
import com.dataiku.dip.io.SocketBlockLinkException;
import com.dataiku.dip.kernels.IDSSKernelBase;
import com.dataiku.dip.resourceusage.ComputeResourceUsageContext;
import com.dataiku.dip.resourceusage.CurrentComputeResourceUsageContext;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.security.process.IsolableProcess;
import com.dataiku.dip.security.rpc.EncryptedRPC;
import com.dataiku.dip.security.tickets.APITicketService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.controllers.NotFoundException;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.shaker.model.SerializedShakerScript;
import com.dataiku.dip.shaker.resources.ResourcesGatherer;
import com.dataiku.dip.spark.SparkJob;
import com.dataiku.dip.spark.SparkJobHelper;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.SecretKeyGenerator;
import com.dataiku.dip.utils.CollectionUtils;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.variables.VariablesContext;
import com.dataiku.dip.variables.VariablesService;
import com.dataiku.dss.shadelib.org.apache.commons.io.FileUtils;
import com.google.common.collect.Lists;
import com.google.gson.JsonObject;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang.StringUtils;
import org.apache.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;

public abstract class EnsembleHandler {
    private static final Logger logger = Logger.getLogger((String)"dku.analysis.prediction.ensembles");

    /**
     * Performs the actual ensembling work.
     *
     * <p>In this file, {@code EnsembleWorkThread.execute()} calls it after {@link #init()}
     * and before {@link #finish()}.
     *
     * @throws Exception if the ensembling work fails
     */
    public abstract void runEnsemble() throws Exception;

    /**
     * Completes the run and releases whatever the handler holds.
     *
     * <p>In this file, {@code EnsembleWorkThread.execute()} calls it from a {@code finally}
     * clause around {@link #runEnsemble()}, so it runs whether or not the run succeeded.
     *
     * @throws Exception if finishing fails
     */
    public abstract void finish() throws Exception;

    /**
     * Requests cancellation of the work in progress.
     *
     * <p>In this file it is invoked from the abort hook registered by
     * {@code EnsembleWorkThread.execute()}.
     *
     * @throws Exception if the abort request fails
     */
    public abstract void abort() throws Exception;

    /**
     * Prepares the handler before {@link #runEnsemble()} is called.
     *
     * @throws Exception if preparation fails
     */
    public abstract void init() throws Exception;

    /**
     * Returns the PID of the process or kernel backing this handler.
     *
     * @return the PID; the implementations in this file return {@code 0} when they hold
     *         no process or kernel
     */
    public abstract int getKernelPid();

    /**
     * Creates the handler implementation matching the backend of the head ML task of
     * {@code fmid}.
     *
     * <p>{@code H2O} and {@code MLLIB} map to {@link SparkEnsembleHandler}; {@code PY_MEMORY}
     * maps to {@link PythonEnsembleHandler}.
     *
     * @param authCtx   authentication context handed to the created handler
     * @param acp       analysis core params (only passed to the Spark handler)
     * @param fmid      id of the ensemble model
     * @param childFmis ids of the models being ensembled
     * @param pps       preprocessing set handed to the created handler
     * @return a new, not yet initialised handler
     * @throws IllegalArgumentException if the backend is none of the supported ones
     */
    static EnsembleHandler from(AuthCtx authCtx, AnalysisCoreParams acp, FullModelId fmid, List<FullModelId> childFmis, WorkSet.PreprocessingSet pps) throws IOException, NotFoundException {
        final MLTask.BackendType backendType = fmid.getHeadMLTask().backendType;
        switch (backendType) {
            case PY_MEMORY:
                return new PythonEnsembleHandler(authCtx, fmid, childFmis, pps);
            case H2O:
            case MLLIB:
                return new SparkEnsembleHandler(authCtx, fmid, acp, childFmis, pps);
            default:
                throw new IllegalArgumentException("Rescoring not available on backend " + backendType);
        }
    }

    /**
     * Builds the future payload describing an ensembling job on the given model.
     *
     * @param fmi id of the model the payload targets
     * @return a payload with action {@code "ensemble"}, a display name, and a single
     *         target built from {@code fmi} with part {@code "ensemble"}
     */
    public static FuturePayload buildFuturePayload(FullModelId fmi) {
        final FuturePayload payload = new FuturePayload();
        payload.displayName = "Ensemble scoring work";
        payload.action = "ensemble";
        payload.targets.add(DSSFuturePayloadUtils.forFMI(fmi).withPart("ensemble"));
        return payload;
    }

    /**
     * Ensemble handler that runs the ensembling as a Spark job
     * ({@code com.dataiku.dip.spark.MLLibEnsemblingJob}).
     *
     * <p>Spring-injected services are only available after {@link #init()} has been called.
     */
    static class SparkEnsembleHandler
    extends EnsembleHandler {
        @Autowired
        private VariablesService variablesService;
        @Autowired
        private APITicketService apiTicketService;
        @Autowired
        private TransactionService transactionService;
        @Autowired
        private DatasetsDAO datasetsDAO;
        // Never assigned anywhere in this class as visible here, so abort() and
        // getKernelPid() currently always take their "no process" path.
        private IsolableProcess process;
        private final FullModelId modelId;
        private final AnalysisCoreParams acp;
        private ResourcesGatherer gatherer = new ResourcesGatherer();
        private AuthCtx user;
        private final List<FullModelId> fmis;
        // Absolute paths of the child models' folders, in the same order as fmis.
        private final List<String> childPreprocessingFolders;
        private final List<String> childModelFolders;

        /**
         * @param ctx     authentication context used for resource gathering, the API ticket and the Spark job
         * @param modelId id of the ensemble model
         * @param acp     analysis core params (project key and shaker script)
         * @param fmis    ids of the models being ensembled
         * @param pps     unused by this handler
         */
        public SparkEnsembleHandler(AuthCtx ctx, FullModelId modelId, AnalysisCoreParams acp, List<FullModelId> fmis, WorkSet.PreprocessingSet pps) {
            this.user = ctx;
            this.modelId = modelId;
            this.acp = acp;
            this.fmis = fmis;
            List<String> preprocessingPaths = new ArrayList<String>();
            List<String> modelPaths = new ArrayList<String>();
            for (FullModelId child : fmis) {
                preprocessingPaths.add(child.getPreprocessingFolder().getAbsolutePath());
                modelPaths.add(child.getModelFolder().getAbsolutePath());
            }
            this.childPreprocessingFolders = preprocessingPaths;
            this.childModelFolders = modelPaths;
        }

        /**
         * Loads the split description referenced by the session's {@code split_ref.json}
         * from the task's splits folder.
         */
        private SplitDesc getSplitDesc() throws IOException {
            String refPath = this.modelId.getSessionFile("split_ref.json").getAbsolutePath();
            SplitDesc.SplitRef ref = (SplitDesc.SplitRef)JSON.parseFile(refPath, SplitDesc.SplitRef.class);
            String instanceId = ref.splitInstanceId;
            File splitsFolder = this.modelId.getTaskLoc().getSplitsFolder();
            File descFile = new File(splitsFolder, instanceId + ".json");
            return (SplitDesc)JSON.parseFile(descFile, SplitDesc.class);
        }

        /**
         * Writes the job inputs into the session folder, then runs the Spark ensembling job
         * and finally writes an empty {@code preprocessing_report.json}.
         */
        @Override
        public void runEnsemble() throws Exception {
            final String hiveDb;
            final SplitDesc splitDesc = this.getSplitDesc();
            final PredictionMLTask task = (PredictionMLTask)this.modelId.getHeadMLTask();
            final File sessionFolder = this.modelId.getSessionFolder();

            JSON.prettyToFile(new SplitDesc.SplitRef(splitDesc.instanceId), new File(sessionFolder, "split_ref.json"));

            // Everything that reads project state happens inside a read transaction.
            try (Transaction t = this.transactionService.beginRead()) {
                SerializedShakerScript scriptCopy = (SerializedShakerScript)JSON.deepCopy(this.acp.script);
                this.gatherer.gatherAndCompute(
                        this.user,
                        this.acp.projectKey,
                        this.acp.script.expandedDeepCopy((VariablesContext)this.variablesService.getForProject(this.acp.projectKey)).steps);
                JSON.prettyToFile(splitDesc, new File(sessionFolder, "split_desc.json"));
                JSON.prettyToFile(scriptCopy, new File(sessionFolder, "escript.json"));
                JSON.prettyToFile(this.gatherer.getResourceMapping(), new File(sessionFolder, "resource_mapping.json"));
                hiveDb = DKUMLUtils.getHiveDb(this.acp, (MLTask)task, this.datasetsDAO, splitDesc.params);
            }

            // The ticket is closed automatically once the job and the report write are over.
            try (APITicketService.ExpirableTicket ticket = this.apiTicketService.createExpiringTicket(this.user, "MLLib doctor ensembling", task)) {
                SparkBasedDoctorJob doctorJob = new SparkBasedDoctorJob(this.user, this.acp.projectKey, this.modelId.getPreprocessingFolder(), task, ticket);

                SparkBasedDoctorJob.SparkDoctorJobBuilder jobBuilder = new SparkBasedDoctorJob.SparkDoctorJobBuilder(){

                    @Override
                    public <T extends SparkJob> T buildSparkJob(SparkJobHelper<T> helper, File runDir, SparkSettings sparkSettings, List<SimpleKeyValue> effectiveConf) throws Exception {
                        String jobName = "DSS (Analysis): " + task.name;
                        boolean isH2O = task.backendType == MLTask.BackendType.H2O;
                        return helper.makeClassJobWithNonSecretGlobalFiles(
                                jobName,
                                effectiveConf,
                                SparkEnsembleHandler.this.gatherer.getResourceFiles(),
                                isH2O,
                                "com.dataiku.dip.spark.MLLibEnsemblingJob",
                                SparkEnsembleHandler.this.acp.projectKey,
                                sessionFolder.getAbsolutePath(),
                                SparkEnsembleHandler.this.modelId.getModelFolder().getAbsolutePath(),
                                SparkEnsembleHandler.this.modelId.getPreprocessingFolder().getAbsolutePath(),
                                JSON.json(SparkEnsembleHandler.this.childModelFolders),
                                JSON.json(SparkEnsembleHandler.this.childPreprocessingFolders));
                    }

                    @Override
                    public Map<String, String> getContextOverrideConf() {
                        return CollectionUtils.appendableSSMap()
                                .put("spark.dku.ml.preparedDF.storageLevel", task.sparkParams.sparkPreparedDFStorageLevel)
                                .put("spark.dku.ml.repartitionNonHDFS", String.valueOf(task.sparkParams.sparkRepartitionNonHDFS))
                                .put("spark.dku.ml.useGlobalMetastore", Boolean.toString(task.sparkParams.sparkUseGlobalMetastore))
                                .put("spark.dku.ml.hiveDb", StringUtils.defaultIfBlank(hiveDb, ""))
                                .get();
                    }

                    @Override
                    public List<File> getExtraRecursiveFolders() {
                        // Session + ensemble folders first, then each child's folders.
                        List<File> folders = new ArrayList<File>();
                        folders.add(sessionFolder);
                        folders.add(SparkEnsembleHandler.this.modelId.getPreprocessingFolder());
                        folders.add(SparkEnsembleHandler.this.modelId.getModelFolder());
                        for (FullModelId child : SparkEnsembleHandler.this.fmis) {
                            folders.add(child.getPreprocessingFolder());
                            folders.add(child.getModelFolder());
                        }
                        return folders;
                    }

                    @Override
                    public List<String> getWritablePaths() {
                        List<String> writable = new ArrayList<String>();
                        writable.add(SparkEnsembleHandler.this.modelId.getModelFolder().getAbsolutePath());
                        return writable;
                    }
                };

                SparkJobHelper.SparkJobPostProcessor postProcessor = new SparkJobHelper.SparkJobPostProcessor(){

                    @Override
                    public void postProcess(SparkJobHelper.SparkJobContext context) throws Exception {
                        if (!context.driverRunsRemotely()) {
                            return;
                        }
                        File modelFolder = SparkEnsembleHandler.this.modelId.getModelFolder();
                        ZipUnzipDir.extractFolder(new File(modelFolder, "trainedModel"), SparkEnsembleHandler.this.modelId.getModelFolder());
                    }
                };

                doctorJob.runSpark(jobBuilder, postProcessor);
                JSON.prettyToFile(new JsonObject(), new File(this.modelId.getPreprocessingFolder(), "preprocessing_report.json"));
            }
        }

        /** Kills the tracked process, if any; an {@link IOException} while killing is only logged. */
        @Override
        public void abort() throws Exception {
            if (this.process == null) {
                return;
            }
            try {
                this.process.niceThenEvilKill();
            }
            catch (IOException e) {
                logger.error("Killing failed", e);
            }
        }

        /** Injects Spring dependencies into this handler and into its resources gatherer. */
        @Override
        public void init() throws Exception {
            SpringUtils springUtils = SpringUtils.getInstance();
            springUtils.autowire(this);
            SpringUtils.getInstance().autowire(this.gatherer);
        }

        /** Nothing to release for this handler. */
        @Override
        public void finish() throws Exception {
        }

        /** @return the working PID of the tracked process, or {@code 0} when there is none */
        @Override
        public int getKernelPid() {
            return this.process == null ? 0 : this.process.getWorkingPid();
        }
    }

    /**
     * Ensemble handler that delegates the work to an analysis ML kernel, started either
     * locally or in a container depending on the resolved container configuration.
     *
     * <p>Lifecycle: {@link #init()} starts the kernel, {@link #runEnsemble()} sends the
     * {@code create_ensemble} command, {@link #finish()} waits for the outcome and tears
     * the kernel and link down.
     */
    private static class PythonEnsembleHandler
    extends EnsembleHandler {
        private final FullModelId modelId;
        // Absolute paths of the child models' folders, in the same order as childFmis.
        private final List<String> childPreprocessingFolders;
        private final List<String> childModelFolders;
        private final List<FullModelId> childFmis;
        private final WorkSet.PreprocessingSet pps;
        private final AuthCtx authCtx;
        // Set by init() and kept until finish(); runEnsemble(), abort() and getKernelPid() rely on it.
        private IAnalysisMLKernel kernel;
        private SingleCommandKernelLink link;
        private boolean runsInContainer = false;
        // Set when the command failed on the link; finish() then waits for the kernel's error instead of results.
        private boolean hasError = false;

        /**
         * @param authCtx   authentication context used to select the container config and start the kernel
         * @param modelId   id of the ensemble model
         * @param childFmis ids of the models being ensembled
         * @param pps       preprocessing set of the ensemble (provides the run folder)
         */
        public PythonEnsembleHandler(AuthCtx authCtx, FullModelId modelId, List<FullModelId> childFmis, WorkSet.PreprocessingSet pps) {
            this.authCtx = authCtx;
            this.modelId = modelId;
            this.childFmis = childFmis;
            this.pps = pps;
            this.childPreprocessingFolders = new ArrayList<String>(childFmis.size());
            this.childModelFolders = new ArrayList<String>(childFmis.size());
            for (FullModelId fmi : childFmis) {
                this.childPreprocessingFolders.add(fmi.getPreprocessingFolder().getAbsolutePath());
                this.childModelFolders.add(fmi.getModelFolder().getAbsolutePath());
            }
            SpringUtils.getInstance().autowire((Object)this);
        }

        /**
         * Creates the kernel link, builds a local or container kernel and starts it.
         *
         * <p>The kernel reference is kept after a successful start so that the other
         * lifecycle methods can use it.
         *
         * @throws InterruptedException if interrupted while starting the kernel (the
         *         thread's interrupt flag is restored before rethrowing)
         */
        @Override
        public void init() throws Exception {
            logger.info((Object)"Acquiring a kernel");
            PortRangeParams dssPortRange = ApplicationConfigurator.getPortRangeParams();
            ResolvedCoreParams coreParams = this.modelId.getResolvedCoreParams();
            String envName = coreParams.executionParams.envName;
            ContainerExecRuntimeConfig containerConfig = new ContainerExecConfigSelector().selectForML_autoTXN(this.authCtx, this.modelId.getProjectKey(), coreParams.executionParams.containerSelection, coreParams.backendType);
            this.runsInContainer = containerConfig != null;
            // An SSL context is only passed when the kernel runs in a container.
            this.link = new SingleCommandKernelLink(SecretKeyGenerator.generate((int)16), dssPortRange, this.runsInContainer ? EncryptedRPC.getSSLContext() : null);
            MLTaskLoc taskLoc = this.modelId.getTaskLoc();
            ComputeResourceUsageContext cruContext = ComputeResourceUsageContext.forAnalysisMLTrain((AuthCtx)this.authCtx, (String)taskLoc.analysisProjectKey, (String)taskLoc.analysisId, (String)taskLoc.mlTaskId, (String)this.modelId.getSessionId());
            CurrentComputeResourceUsageContext.setInCurrentThread((ComputeResourceUsageContext)cruContext);
            if (!this.runsInContainer) {
                this.kernel = new AnalysisMLKernel(this.link, this.pps, this.modelId.getProjectKey(), envName, this.authCtx, new File(this.pps.run_folder));
            } else {
                AnalysisMLContainerKernel containerKernel = new AnalysisMLContainerKernel(this.link, this.pps, this.modelId.getProjectKey(), envName, this.authCtx, new File(this.pps.run_folder), this.modelId.getTaskLoc().getDataFolder(), containerConfig, "doctor-ensemble-");
                this.kernel = containerKernel;
            }
            try {
                this.kernel.start();
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw e;
            }
            // Do NOT clear this.kernel here: dropping the reference would leak the started
            // kernel and turn finish(), abort() and getKernelPid() into no-ops.
        }

        /**
         * Waits for the kernel outcome (results, or the serialized error when the command
         * failed), then cleans up and kills the kernel and closes the link.
         *
         * <p>Failures during teardown are logged, not thrown. A failure while waiting for
         * results, or a serialized error reported by the kernel, is rethrown after teardown.
         */
        @Override
        public void finish() throws Exception {
            Exception error = null;
            try {
                if (this.kernel != null) {
                    if (!this.kernel.isAborted()) {
                        if (!this.hasError) {
                            try {
                                this.kernel.waitForResults();
                                ModelStateHelper.updatePredictionTrainInfoAndUserMeta(((PredictionMLTask)this.modelId.getHeadMLTask()).predictionType, this.pps);
                            }
                            catch (Exception e) {
                                // Remembered and rethrown once teardown is done.
                                error = e;
                            }
                        } else {
                            SerializedError serializedError = this.kernel.waitForError();
                            if (serializedError != null) {
                                error = new APIError.SerializedErrorException(serializedError);
                            }
                        }
                    }
                    this.kernel.cleanup();
                    this.kernel.killWithoutMercy();
                }
            }
            catch (Exception e) {
                logger.error((Object)"Failure while destroying ml kernel", (Throwable)e);
            }
            finally {
                this.kernel = null;
            }
            try {
                this.link.close();
            }
            catch (Exception e) {
                logger.error((Object)"Failure while closing link to kernel", (Throwable)e);
            }
            if (error != null) {
                throw error;
            }
        }

        /** @return the kernel's PID, or {@code 0} when no kernel is currently held */
        @Override
        public int getKernelPid() {
            return this.kernel == null ? 0 : this.kernel.getPid();
        }

        /** Serializes the model's resolved core params to JSON. */
        private String getCoreParams() throws IOException {
            return JSON.json((Object)this.modelId.getResolvedCoreParams());
        }

        /** Asks the kernel to abort, if one is currently held. */
        @Override
        public void abort() throws Exception {
            if (this.kernel != null) {
                this.kernel.abort();
            }
        }

        /**
         * Copies {@code collector_data.json} from the first child model's preprocessing
         * files to the ensemble's, then sends the {@code create_ensemble} command to the
         * kernel over the link.
         *
         * <p>On a {@link SocketBlockLinkException} the handler is flagged as failed (so that
         * {@link #finish()} waits for the kernel's error) and the exception is rethrown via
         * the kernel's {@code maybeRethrowAsProcessDied}.
         */
        @Override
        public void runEnsemble() throws Exception {
            FileUtils.copyFile((File)this.childFmis.get(0).getPreprocessingFile("collector_data.json"), (File)this.modelId.getPreprocessingFile("collector_data.json"));
            JsonObject command = new JsonObject();
            command.addProperty("split_desc", JSON.json((Object)this.modelId.getSplitDesc()));
            command.addProperty("core_params", this.getCoreParams());
            command.addProperty("model_folder", this.modelId.getModelFolder().getAbsolutePath());
            command.addProperty("preprocessing_folder", this.modelId.getPreprocessingFolder().getAbsolutePath());
            command.addProperty("split_folder", this.modelId.getSplitFolder().getAbsolutePath());
            command.addProperty("model_folders", JSON.json(this.childModelFolders));
            command.addProperty("preprocessing_folders", JSON.json(this.childPreprocessingFolders));
            try {
                this.link.executeAsync((Object)new AnalysisMLKernel.ComputeRequest("create_ensemble", JSON.json((Object)command)), null, String.class, "Failed to ensemble").call();
            }
            catch (SocketBlockLinkException e) {
                this.hasError = true;
                throw this.kernel.maybeRethrowAsProcessDied((IOException)((Object)e.withLogTail((IDSSKernelBase)this.kernel)));
            }
        }
    }

    public static class EnsembleWorkThread
    extends MLBaseService.MLTaskWorkThread {
        @Autowired
        private MLBaseService mlBaseService;
        private final FullModelId modelId;
        private final List<FullModelId> childFmis;
        private Integer pid = 0;
        private final WorkSet ws;

        public EnsembleWorkThread(DSSAuthCtx user, MLTaskLoc loc, AnalysisCoreParams acp, WorkSet ws, FullModelId modelId, List<FullModelId> childFmis) {
            super(user, loc, acp);
            this.modelId = modelId;
            this.childFmis = childFmis;
            this.ws = ws;
            SpringUtils.getInstance().autowire((Object)this);
        }

        @Override
        public FuturePayload buildFuturePayload(MLTaskLoc loc, AnalysisCoreParams cp) {
            return MLBaseService.buildBaseFuturePayload(loc, "Ensembling prediction models", "train", "prediction");
        }

        @Override
        public void init() throws Exception {
            try {
                this.mlBaseService.startLogAppender();
            }
            catch (Exception e) {
                this.mlBaseService.setNotWorking(this.loc);
                throw e;
            }
        }

        @Override
        public List<FullModelId> modelsBeingTrained() {
            return Lists.newArrayList((Object[])new FullModelId[]{this.modelId});
        }

        @Override
        public double getDangerosity() {
            return 0.0;
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        @Override
        public List<Integer> getPids() {
            EnsembleWorkThread ensembleWorkThread = this;
            synchronized (ensembleWorkThread) {
                return Lists.newArrayList((Object[])new Integer[]{this.pid});
            }
        }

        public Void getResult() {
            return null;
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         * Loose catch block
         */
        public void execute() throws Exception {
            block45: {
                Throwable error;
                block44: {
                    mlTaskContext.set(new MLBaseService.MLTaskWorkThread.MLTaskContext(Collections.singletonList(this.modelId)));
                    logger.info((Object)"******************************************");
                    logger.info((Object)("** Start train session " + this.modelId.getSessionId()));
                    logger.info((Object)"******************************************");
                    error = null;
                    Object handler = EnsembleHandler.from(this.owner, this.cp, this.modelId, this.childFmis, this.ws.preprocessingSets.get(0));
                    EnsembleWorkThread ensembleWorkThread = this;
                    synchronized (ensembleWorkThread) {
                        handler.init();
                        this.pid = handler.getKernelPid();
                    }
                    try (FutureAborter.AutoCloseableAbortHook aborter = FutureAborter.pushAutoCloseableHook((Runnable)new Runnable(){
                        final /* synthetic */ EnsembleHandler val$handler;
                        {
                            this.val$handler = ensembleHandler;
                        }

                        @Override
                        public void run() {
                            try {
                                this.val$handler.abort();
                            }
                            catch (Exception e) {
                                logger.warn((Object)"Error white aborting the ensembling task", (Throwable)e);
                            }
                        }
                    });){
                        handler.runEnsemble();
                        ModelStateHelper.updatePredictionTrainInfoAndUserMeta(((PredictionMLTask)this.modelId.getHeadMLTask()).predictionType, this.ws.preprocessingSets.get(0));
                    }
                    finally {
                        handler.finish();
                    }
                    if (error == null) break block44;
                    handler = this;
                    synchronized (handler) {
                        for (WorkSet.PreprocessingSet pps : this.ws.preprocessingSets) {
                            ModelStateHelper.markAllNotDoneAsFailed(pps, error);
                        }
                    }
                }
                this.mlBaseService.setNotWorking(this.loc);
                try {
                    FileUtils.touch((File)new File(this.modelId.getSessionFolder(), "done.txt"));
                }
                catch (IOException e) {
                    logger.error((Object)"Failed to mark session as complete", (Throwable)e);
                }
                mlTaskContext.set(null);
                break block45;
                catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    logger.error((Object)"Processing failed", (Throwable)e);
                    error = e;
                    break block45;
                }
                catch (Throwable e2222222222) {
                    block46: {
                        this.mlBaseService.setQueueState(this.loc, false);
                        logger.error((Object)"Processing failed", e2222222222);
                        error = e2222222222;
                        if (error == null) break block46;
                        EnsembleWorkThread e2222222222 = this;
                        synchronized (e2222222222) {
                            for (WorkSet.PreprocessingSet pps : this.ws.preprocessingSets) {
                                ModelStateHelper.markAllNotDoneAsFailed(pps, error);
                            }
                        }
                    }
                    this.mlBaseService.setNotWorking(this.loc);
                    try {
                        FileUtils.touch((File)new File(this.modelId.getSessionFolder(), "done.txt"));
                    }
                    catch (IOException e3) {
                        logger.error((Object)"Failed to mark session as complete", (Throwable)e3);
                    }
                    mlTaskContext.set(null);
                    break block45;
                    {
                        catch (Throwable throwable) {
                            throw throwable;
                        }
                    }
                }
                finally {
                    if (error != null) {
                        EnsembleWorkThread e = this;
                        synchronized (e) {
                            for (WorkSet.PreprocessingSet pps : this.ws.preprocessingSets) {
                                ModelStateHelper.markAllNotDoneAsFailed(pps, error);
                            }
                        }
                    }
                    this.mlBaseService.setNotWorking(this.loc);
                    try {
                        FileUtils.touch((File)new File(this.modelId.getSessionFolder(), "done.txt"));
                    }
                    catch (IOException e) {
                        logger.error((Object)"Failed to mark session as complete", (Throwable)e);
                    }
                    mlTaskContext.set(null);
                }
            }
        }

        public String getModelId() {
            return this.modelId.toString();
        }
    }
}

