/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.clustering;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.clustering.ClusteringRescoringHandler;
import com.dataiku.dip.analysis.ml.spark.SparkBasedDoctorJob;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.cluster.SparkSettings;
import com.dataiku.dip.coremodel.SimpleKeyValue;
import com.dataiku.dip.export.ZipUnzipDir;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.process.IsolableProcess;
import com.dataiku.dip.security.tickets.APITicketService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.shaker.model.SerializedShakerScript;
import com.dataiku.dip.shaker.resources.ResourcesGatherer;
import com.dataiku.dip.spark.SparkJob;
import com.dataiku.dip.spark.SparkJobHelper;
import com.dataiku.dip.utils.CollectionUtils;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.variables.VariablesService;
import com.google.common.collect.Lists;
import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import org.apache.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;

public class MLLibClusteringRescoringHandler
extends ClusteringRescoringHandler {
    @Autowired
    private VariablesService variablesService;
    @Autowired
    private APITicketService apiTicketService;
    private final FullModelId modelId;
    private final AuthCtx authCtx;
    private IsolableProcess process;
    private static Logger logger = Logger.getLogger((String)"dku.analysis.prediction");

    public MLLibClusteringRescoringHandler(AuthCtx authCtx, FullModelId modelId) {
        this.modelId = modelId;
        this.authCtx = authCtx;
        SpringUtils.getInstance().autowire((Object)this);
    }

    @Override
    public void abort() {
        if (this.process != null) {
            try {
                this.process.niceThenEvilKill();
            }
            catch (IOException e) {
                logger.error((Object)"Killing failed", (Throwable)e);
            }
        }
    }

    @Override
    public void rescore() throws Exception {
        final MLTask task = this.modelId.getHeadMLTask();
        SerializedShakerScript expandedScript = (SerializedShakerScript)JSON.parseFile((File)this.modelId.getSessionFile("script.json"), SerializedShakerScript.class);
        expandedScript = expandedScript.expandedDeepCopy(this.variablesService.getForProject(this.modelId.getProjectKey()));
        final ResourcesGatherer gatherer = new ResourcesGatherer();
        SpringUtils.getInstance().autowire((Object)gatherer);
        gatherer.gatherAndCompute(this.authCtx, this.modelId.getProjectKey(), expandedScript.steps);
        JSON.prettyToFile(gatherer.getResourceMapping(), (File)new File(this.modelId.getSessionFolder(), "resource_mapping.json"));
        try (APITicketService.ExpirableTicket ticket = this.apiTicketService.createExpiringTicket(this.authCtx, "MLLib doctor cluster rescoring", (Object)task);){
            SparkBasedDoctorJob doctorJob = new SparkBasedDoctorJob(this.authCtx, this.modelId.getProjectKey(), this.modelId.getPreprocessingFolder(), task, ticket);
            doctorJob.runSpark(new SparkBasedDoctorJob.SparkDoctorJobBuilder(){

                @Override
                public <T extends SparkJob> T buildSparkJob(SparkJobHelper<T> helper, File runDir, SparkSettings sparkSettings, List<SimpleKeyValue> effectiveConf) throws Exception {
                    return helper.makeClassJobWithNonSecretGlobalFiles("DSS (Analysis): rescoring of " + MLLibClusteringRescoringHandler.this.modelId.toString(), effectiveConf, gatherer.getResourceFiles(), task.backendType == MLTask.BackendType.H2O, "com.dataiku.dip.spark.ClusteringRescoringDoctorJob", MLLibClusteringRescoringHandler.this.modelId.getProjectKey(), MLLibClusteringRescoringHandler.this.modelId.getSessionFolder().getAbsolutePath(), MLLibClusteringRescoringHandler.this.modelId.getPreprocessingFolder().getAbsolutePath(), MLLibClusteringRescoringHandler.this.modelId.getModelFolder().getAbsolutePath());
                }

                @Override
                public Map<String, String> getContextOverrideConf() {
                    return CollectionUtils.appendableSSMap().put("spark.dku.ml.preparedDF.storageLevel", task.sparkParams.sparkPreparedDFStorageLevel).put("spark.dku.ml.repartitionNonHDFS", String.valueOf(task.sparkParams.sparkRepartitionNonHDFS)).get();
                }

                @Override
                public List<File> getExtraRecursiveFolders() {
                    return Lists.newArrayList((Object[])new File[]{MLLibClusteringRescoringHandler.this.modelId.getSessionFolder(), MLLibClusteringRescoringHandler.this.modelId.getPreprocessingFolder(), MLLibClusteringRescoringHandler.this.modelId.getModelFolder()});
                }

                @Override
                public List<String> getWritablePaths() {
                    return Lists.newArrayList((Object[])new String[]{MLLibClusteringRescoringHandler.this.modelId.getModelFolder().getAbsolutePath()});
                }
            }, new SparkJobHelper.SparkJobPostProcessor(){

                @Override
                public void postProcess(SparkJobHelper.SparkJobContext context) throws Exception {
                    if (context.driverRunsRemotely()) {
                        ZipUnzipDir.extractFolder(new File(MLLibClusteringRescoringHandler.this.modelId.getModelFolder(), "trainedModel"), MLLibClusteringRescoringHandler.this.modelId.getModelFolder());
                    }
                }
            });
        }
    }
}

