/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml;

import com.dataiku.dip.DSSMetrics;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLPaths;
import com.dataiku.dip.analysis.ml.MLTaskHandler;
import com.dataiku.dip.analysis.ml.MLTaskLoc;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.ml.shared.ModelStateHelper;
import com.dataiku.dip.analysis.ml.shared.PRNSTrainThread;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.core.AnalysisCoreParams;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.distributed.metrics.ContainerUsageMetrics;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.common.collect.Lists;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;

/**
 * Base {@link MLTaskHandler} whose training is carried out by a set of {@link PRNSTrainThread}s.
 *
 * <p>Subclasses supply the split preparation, the train thread to run for a split, the number of
 * threads to start and the name of the DSS meter marked on every training run. This class drives
 * the main training loop and forwards abort requests and kernel / container-usage queries to the
 * threads it has started.
 *
 * @param <T> the concrete ML task type handled
 */
public abstract class PythonMLTaskHandler<T extends MLTask>
extends MLTaskHandler<T> {
    /**
     * Train threads started by {@link #runTraining(SplitDesc)}.
     *
     * <p>This list is appended to by the (unsynchronized) training loop while {@link #abort()},
     * {@link #abort(List)}, {@link #getKernelPids()} and {@link #getContainerUsageMetricsPerModel()}
     * iterate over it, presumably from other threads. A copy-on-write list lets those readers
     * iterate a consistent snapshot instead of risking a ConcurrentModificationException. Writes
     * happen only once per started thread, so the copy cost should be small.
     */
    protected final List<PRNSTrainThread> processingThreads = new CopyOnWriteArrayList<PRNSTrainThread>();
    /** Set once {@link #abort()} has been called; guarded by this handler's monitor. */
    private boolean aborting = false;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.analysis");

    public PythonMLTaskHandler(AnalysisCoreParams acp, MLTaskLoc taskLoc, T task, String sessionId, AuthCtx user) {
        super(acp, taskLoc, task, sessionId, user);
    }

    /**
     * Aborts the whole session: raises the aborting flag and forwards the abort to every train
     * thread started so far.
     *
     * <p>NOTE(review): {@link #runTraining(SplitDesc)} does not consult the aborting flag, so
     * threads started after this call are not aborted here. Subclasses presumably check
     * {@link #isAborting()}; confirm.
     */
    @Override
    public synchronized void abort() {
        this.aborting = true;
        logger.info((Object)("Aborting " + this.sessionId));
        PRNSTrainThread.abort(this.processingThreads);
    }

    /**
     * @return whether {@link #abort()} has been called on this handler
     */
    public synchronized boolean isAborting() {
        return this.aborting;
    }

    /**
     * Aborts only the given models by forwarding the request to every started train thread.
     * Unlike {@link #abort()}, this does not set the aborting flag.
     *
     * @param fullModelIdSet ids of the models to abort
     * @throws IOException if a train thread fails while handling the partial abort
     */
    @Override
    public synchronized void abort(List<FullModelId> fullModelIdSet) throws IOException {
        logger.info((Object)("Partial abort " + this.sessionId));
        for (PRNSTrainThread tat : this.processingThreads) {
            tat.partialAbort(fullModelIdSet);
        }
    }

    /**
     * Creates (but does not start) a train thread working on the given split.
     *
     * @param var1 the split prepared by {@link #prepareSplits()}
     * @return a new, not yet started train thread
     */
    protected abstract PRNSTrainThread createTrainThread(SplitDesc var1);

    /**
     * Prepares the train/test split(s) used by this session.
     *
     * @return the split description whose {@code instanceId} is recorded in {@code split_ref.json}
     * @throws Exception on any preparation failure (handled by {@link #train()})
     */
    protected abstract SplitDesc prepareSplits() throws Exception;

    /**
     * Validates a prepared split. This hook is not called from this class; subclasses decide
     * when to call it.
     *
     * @param var1 the split to check
     */
    protected abstract void checkSplits(SplitDesc var1);

    /**
     * @return name of the DSS meter marked each time {@link #train()} is called
     */
    protected abstract String getDSSMetricName();

    /**
     * @return number of train threads {@link #runTraining(SplitDesc)} starts
     */
    protected abstract int getThreadCountToRun();

    /**
     * Runs the main training loop.
     *
     * <p>Marks the DSS meter, then prepares the splits. Next it writes {@code split_ref.json}
     * (pointing at the split instance id) into the session folder. Finally it calls
     * {@code dispatchWork()} and runs the train threads until they finish.
     *
     * <p>On any exception, the failure is logged and every not-done model of each preprocessing
     * set is marked as failed. The exception is then rethrown.
     *
     * @throws Exception whatever split preparation, dispatching or training threw
     */
    @Override
    public void train() throws Exception {
        DSSMetrics.registry().meter(this.getDSSMetricName()).mark();
        try {
            SplitDesc splitDesc = this.prepareSplits();
            JSON.prettyToFile((Object)new SplitDesc.SplitRef(splitDesc.instanceId), (File)new File(MLPaths.sessionFolder(this.taskLoc, this.sessionId), "split_ref.json"));
            this.dispatchWork();
            this.runTraining(splitDesc);
        }
        catch (Exception e) {
            logger.error((Object)"Failure while training main loop", (Throwable)e);
            // Model states are updated under this handler's monitor, the same lock held by
            // abort(), abort(List) and isAborting().
            synchronized (this) {
                for (WorkSet.PreprocessingSet pps : this.ws.preprocessingSets) {
                    ModelStateHelper.markAllNotDoneAsFailed(pps, e);
                }
            }
            throw e;
        }
    }

    /**
     * Creates, registers and starts {@link #getThreadCountToRun()} train threads for the given
     * split, then blocks until all registered threads have been joined.
     *
     * @param splitDesc the split every thread trains on
     * @throws InterruptedException if interrupted while joining the threads
     */
    protected void runTraining(SplitDesc splitDesc) throws InterruptedException {
        for (int i = 0; i < this.getThreadCountToRun(); ++i) {
            PRNSTrainThread tat = this.createTrainThread(splitDesc);
            // Register before starting so abort()/metrics queries can see the thread as soon as it runs.
            this.processingThreads.add(tat);
            tat.start();
        }
        PRNSTrainThread.join(this.processingThreads);
        logger.info((Object)"Train done");
    }

    /**
     * @return the non-zero working pids reported by the started train threads, in thread order
     */
    @Override
    public List<Integer> getKernelPids() {
        List<Integer> workingPids = Lists.newArrayList();
        for (PRNSTrainThread tat : this.processingThreads) {
            int pid = tat.getWorkingPid();
            // A pid of 0 is skipped; presumably it means "no working process right now" — confirm.
            if (pid != 0) {
                workingPids.add(pid);
            }
        }
        return workingPids;
    }

    /**
     * Maps each model currently handled by a train thread to that thread's container usage
     * metrics. Models of the same thread share the same metrics instance.
     *
     * @return a new map from model id to container usage metrics
     */
    @Override
    public Map<FullModelId, ContainerUsageMetrics> getContainerUsageMetricsPerModel() {
        HashMap<FullModelId, ContainerUsageMetrics> containerUsageMetricsPerModel = new HashMap<FullModelId, ContainerUsageMetrics>();
        for (PRNSTrainThread tat : this.processingThreads) {
            ContainerUsageMetrics containerUsageMetricsForThread = tat.getContainerUsageMetrics();
            for (FullModelId fmi : tat.getCurrentFullModelIds()) {
                containerUsageMetricsPerModel.put(fmi, containerUsageMetricsForThread);
            }
        }
        return containerUsageMetricsPerModel;
    }
}

