/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.coreservices;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.analysis.coreservices.AnalysisCRUDService;
import com.dataiku.dip.analysis.coreservices.AnalysisDataService;
import com.dataiku.dip.analysis.coreservices.ClusteringService;
import com.dataiku.dip.analysis.coreservices.MLTaskCodeEnvCompatibilityComputer;
import com.dataiku.dip.analysis.coreservices.PredictionService;
import com.dataiku.dip.analysis.ml.DKUMLUtils;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLPaths;
import com.dataiku.dip.analysis.ml.MLTaskLoc;
import com.dataiku.dip.analysis.ml.interactivemodel.InteractiveModelService;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.ml.prediction.split.SplitGenerator;
import com.dataiku.dip.analysis.ml.shared.AnalysisTrainLoggingAppender;
import com.dataiku.dip.analysis.ml.shared.ModelStateHelper;
import com.dataiku.dip.analysis.ml.shared.ResultsReaderBase;
import com.dataiku.dip.analysis.ml.shared.TrainDiagnosisGenerator;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.AnalysisCoreParams;
import com.dataiku.dip.analysis.model.core.ModelUserMeta;
import com.dataiku.dip.analysis.model.prediction.PreTrainStatus;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.preprocessing.CatFeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.NumFeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.TextFeaturePreprocessingParams;
import com.dataiku.dip.code.CodeEnvModel;
import com.dataiku.dip.code.CodeEnvSelection;
import com.dataiku.dip.code.CodeEnvSelector;
import com.dataiku.dip.code.DesignNodeCodeEnvsService;
import com.dataiku.dip.code.PythonCodeEnvPackagesUtils;
import com.dataiku.dip.containers.exec.ContainerExecConfigSelector;
import com.dataiku.dip.containers.exec.ContainerExecRuntimeConfig;
import com.dataiku.dip.containers.exec.WorkloadType;
import com.dataiku.dip.coremodel.InfoMessage;
import com.dataiku.dip.dao.UsersDAO;
import com.dataiku.dip.distributed.metrics.ContainerUsageMetrics;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.futures.DSSFuturePayloadUtils;
import com.dataiku.dip.futures.FuturePayload;
import com.dataiku.dip.futures.FutureService;
import com.dataiku.dip.futures.FutureThread;
import com.dataiku.dip.io.DockerSimplePythonKernel;
import com.dataiku.dip.io.KubernetesSimplePythonKernel;
import com.dataiku.dip.io.LocalSimplePythonKernel;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.security.auth.NotLoggedInException;
import com.dataiku.dip.security.impersonation.FilesystemACLUtils;
import com.dataiku.dip.security.model.PublicAPIKey;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.api.auth.PublicAPIKeysService;
import com.dataiku.dip.server.notifications.DSSEvent;
import com.dataiku.dip.server.notifications.backend.ModelVersionDeletedEvent;
import com.dataiku.dip.server.services.PubSubService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.server.services.UsersService;
import com.dataiku.dip.shaker.server.MemScriptRunner;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.AutoCloseableLock;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.DKUtils;
import com.dataiku.dip.utils.ErrorContext;
import com.dataiku.dip.utils.ExceptionUtils;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.NamedLock;
import com.dataiku.dip.utils.NotImplementedException;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang.StringUtils;
import org.apache.log4j.Appender;
import org.apache.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class MLBaseService {
    // Collaborating services, injected by Spring through @Autowired.
    @Autowired
    protected AnalysisDataService dataService;
    @Autowired
    protected AnalysisCRUDService crudService;
    @Autowired
    protected DesignNodeCodeEnvsService designNodeEnvsService;
    @Autowired
    protected InteractiveModelService interactiveModelService;
    @Autowired
    protected TransactionService transactionService;
    @Autowired
    protected FutureService futureService;
    @Autowired
    protected PredictionService predictionService;
    @Autowired
    protected ClusteringService clusteringService;
    @Autowired
    protected UsersService usersService;
    @Autowired
    protected PublicAPIKeysService publicAPIKeysService;
    @Autowired
    protected PubSubService pubSubService;
    // Appender that startLogAppender() registers with DKUtils and with the local/Docker/Kubernetes Python kernel loggers.
    private AnalysisTrainLoggingAppender appender = new AnalysisTrainLoggingAppender();
    // ML tasks with a guess in progress. Accessed through the synchronized guessing accessors.
    private Set<MLTaskLoc> guessing = new HashSet<MLTaskLoc>();
    // ML tasks currently training, mapped to their work thread. Accessed through synchronized accessors.
    private Map<MLTaskLoc, MLTaskWorkThread> working = new HashMap<MLTaskLoc, MLTaskWorkThread>();
    // Per-task training queue status (RUNNING / PAUSED). getQueueStatus() reports UNKNOWN for tasks with no entry.
    private Map<MLTaskLoc, MLTaskQueueStatus> taskQueues = new HashMap<MLTaskLoc, MLTaskQueueStatus>();
    // NOTE(review): not assigned anywhere in the visible part of this file; presumably initialized
    // in a static initializer further down the class — confirm.
    private static Logger logger;

    /**
     * Returns the column names of the cached, unfiltered input table of the analysis.
     *
     * <p>The result is a copy: the previous implementation handed out the live {@code keySet()} view
     * of the cached table's column map, so any caller removing a name (e.g. dropping the target)
     * silently removed the column from the shared cache. Iteration order of the cached map is kept.
     *
     * @param acp  analysis core parameters identifying the input
     * @param user authentication context used to read the data
     * @return a mutable, caller-owned set of column names
     * @throws Exception if the cached table cannot be obtained
     */
    public Set<String> getColumnNames(AnalysisCoreParams acp, AuthCtx user) throws Exception {
        MemScriptRunner.TableWithReport dataTable = this.dataService.getCachedUnfiltered_NOTRANSACTION(acp, user);
        return new java.util.LinkedHashSet<String>(dataTable.table.columns.keySet());
    }

    /**
     * Asks each given model's grid search to stop by making sure its stop-search marker file
     * (as returned by {@code getStopSearchFile()}) exists.
     *
     * @param fullModelIds models whose search should stop
     * @throws Exception if a marker file is absent and cannot be created
     */
    public void stopGridSearch_NT(List<FullModelId> fullModelIds) throws Exception {
        for (FullModelId modelId : fullModelIds) {
            File marker = modelId.getStopSearchFile();
            boolean markerPresent = marker.exists() || marker.createNewFile();
            if (!markerPresent) {
                throw new Exception("Failed to create stop_search file " + String.valueOf(modelId.getModelFolder()));
            }
        }
    }

    /**
     * Asks every model of a session to stop its grid search by creating a {@code stop_search} file in
     * each {@code pp<N>/m<N>} model folder of the session.
     *
     * <p>{@link File#listFiles()} returns {@code null} when the folder does not exist or cannot be
     * read; that previously surfaced as a NullPointerException. A missing session or preprocessing
     * folder now simply means there is no model to stop there.
     *
     * @param mlTaskLoc ML task owning the session
     * @param sessionId session identifier
     * @throws Exception if a stop_search file is absent and cannot be created
     */
    public void stopGridSearchSession_NT(MLTaskLoc mlTaskLoc, String sessionId) throws Exception {
        File sessionFolder = mlTaskLoc.getSessionFolder(sessionId);
        File[] ppFolders = sessionFolder.listFiles();
        if (ppFolders == null) {
            logger.info((Object)("No session folder to stop grid search in: " + sessionFolder.getAbsolutePath()));
            return;
        }
        for (File ppFolder : ppFolders) {
            if (!ppFolder.isDirectory() || !ppFolder.getName().matches("pp[0-9]+")) continue;
            File[] modelFolders = ppFolder.listFiles();
            if (modelFolders == null) continue;
            for (File mFolder : modelFolders) {
                if (!mFolder.isDirectory() || !mFolder.getName().matches("m[0-9]+")) continue;
                File stopSearchFile = new File(mFolder, "stop_search");
                if (stopSearchFile.exists() || stopSearchFile.createNewFile()) continue;
                throw new Exception("Failed to create stop_search file " + stopSearchFile.getAbsolutePath());
            }
        }
    }

    /** Tells whether the session folder contains a {@code done.txt} marker file. */
    private boolean isSessionDone(File sessionFolder) {
        String[] markerPath = new String[]{"done.txt"};
        return DKUFileUtils.exists(sessionFolder, markerPath);
    }

    /**
     * Lists the names of the task's session folders that carry a {@code done.txt} marker.
     *
     * @param taskLoc ML task whose sessions folder is scanned
     * @return completed session ids; empty when the sessions folder is not a directory
     */
    private List<String> listCompletedSessions(MLTaskLoc taskLoc) {
        ArrayList<String> doneSessions = new ArrayList<String>();
        File sessionsFolder = taskLoc.getSessionsFolder();
        if (sessionsFolder.isDirectory()) {
            for (File candidate : DKUFileUtils.listSubfoldersOf(taskLoc.getSessionsFolder())) {
                if (this.isSessionDone(candidate)) {
                    doneSessions.add(candidate.getName());
                }
            }
        }
        return doneSessions;
    }

    /**
     * Tells whether a model's {@code train_info.json} reports a state that is set and not "being
     * trained". Any failure while reading or inspecting the file is logged and reported as "not finished".
     */
    private boolean isModelFinished(FullModelId fmi) {
        try {
            ModelTrainInfo trainInfo = fmi.parseModelFile("train_info.json", ModelTrainInfo.class);
            boolean hasState = trainInfo.state != null;
            return hasState && !trainInfo.state.isBeingTrained();
        }
        catch (Exception e) {
            logger.warn("Error trying to parse train_info.json for model " + String.valueOf(fmi), e);
            return false;
        }
    }

    /**
     * Collects the ids of completed models across all sessions of the task. Every model of a session
     * with {@code done.txt} counts as completed; otherwise each model is checked individually through
     * its train info.
     *
     * @param taskLoc ML task to scan
     * @return de-duplicated completed model ids, in no particular order
     */
    public List<FullModelId> listCompletedModelIds(MLTaskLoc taskLoc) {
        HashSet<FullModelId> done = new HashSet<FullModelId>();
        for (File sessionFolder : DKUFileUtils.listSubfoldersOf(taskLoc.getSessionsFolder())) {
            List<FullModelId> sessionModels = taskLoc.listModelIds(sessionFolder.getName());
            if (!this.isSessionDone(sessionFolder)) {
                for (FullModelId candidate : sessionModels) {
                    if (this.isModelFinished(candidate)) {
                        done.add(candidate);
                    }
                }
            } else {
                done.addAll(sessionModels);
            }
        }
        return new ArrayList<FullModelId>(done);
    }

    /**
     * Parses the numeric part of a session id, i.e. everything after its first character.
     *
     * @throws NumberFormatException if the remainder is not an integer
     */
    public int getSessionNumber(String sessionId) {
        String digits = sessionId.substring(1);
        return Integer.parseInt(digits);
    }

    /**
     * Lists the task's queued sessions (see {@code sessionIsQueued}), ordered by ascending session
     * number. If a session's {@code queue_info.json} cannot be read, the failure is logged and the
     * session is still listed without its metadata being set.
     *
     * @param taskLoc ML task whose sessions folder is scanned
     * @return queued sessions; empty when the sessions folder is not a directory
     */
    public List<QueuedSession> listQueuedSessions(MLTaskLoc taskLoc) {
        ArrayList<QueuedSession> queued = new ArrayList<QueuedSession>();
        if (!taskLoc.getSessionsFolder().isDirectory()) {
            return queued;
        }
        for (File sessionFolder : taskLoc.getSessionsFolder().listFiles()) {
            if (!this.sessionIsQueued(sessionFolder)) {
                continue;
            }
            QueuedSession entry = new QueuedSession();
            entry.id = sessionFolder.getName();
            try {
                File infoFile = new File(sessionFolder, "queue_info.json");
                entry.metadata = (QueuedSessionMetadata)JSON.parseFile(infoFile, QueuedSessionMetadata.class);
            }
            catch (IOException e) {
                logger.info("Failed to read queue_info.json", e);
            }
            queued.add(entry);
        }
        queued.sort(Comparator.comparingInt((QueuedSession s) -> this.getSessionNumber(s.id)));
        return queued;
    }

    /** Returns how many sessions of the task are currently queued. */
    public int getQueueLength(MLTaskLoc taskLoc) {
        List<QueuedSession> queued = this.listQueuedSessions(taskLoc);
        return queued.size();
    }

    /**
     * Starts training the task's queue unless the queue is paused. The work is delegated to the
     * prediction or clustering service depending on the task type. If the task type cannot be read,
     * the failure is logged, the queue is paused, and nothing is trained.
     *
     * @param cp  analysis core parameters
     * @param loc ML task whose queue should run
     * @throws Exception propagated from the delegated trainQueue call
     */
    public void checkAndRunQueue(AnalysisCoreParams cp, MLTaskLoc loc) throws Exception {
        if (this.getQueueStatus(loc) == MLTaskQueueStatus.PAUSED) {
            return;
        }
        MLTask.MLTaskType kind = null;
        try (Transaction t = this.transactionService.beginRead();){
            kind = this.crudService.getMLTask(loc).taskType;
        }
        catch (Exception e) {
            logger.info("Failed to get ML task type", e);
            this.setQueueState(loc, false);
        }
        if (kind == MLTask.MLTaskType.PREDICTION) {
            this.predictionService.trainQueue(cp, loc);
        } else if (kind == MLTask.MLTaskType.CLUSTERING) {
            this.clusteringService.trainQueue(cp, loc);
        }
    }

    /**
     * Removes the given sessions from the queue by deleting their {@code queue_info.json}. Sessions
     * that are not queued are ignored; deletion failures are logged and skipped.
     */
    public void deleteQueuedSessionsById(MLTaskLoc taskLoc, List<String> sessionList) {
        for (String sessionId : sessionList) {
            File sessionFolder = taskLoc.getSessionFolder(sessionId);
            if (this.sessionIsQueued(sessionFolder)) {
                File queueInfo = new File(sessionFolder, "queue_info.json");
                try {
                    DKUFileUtils.delete(queueInfo);
                }
                catch (IOException e) {
                    logger.info("Failed to delete queue_info.json", e);
                }
            }
        }
    }

    /**
     * Picks the queued session whose folder has the oldest last-modified time.
     *
     * <p>Ordering is by folder mtime, not by session number as in listQueuedSessions. This is kept
     * as-is; it presumably matters while a session is being written to — confirm before changing.
     *
     * @return the session id, or {@code null} when no session is queued or the folder is empty/unreadable
     */
    public String getNextQueuedSessionId(MLTaskLoc loc) {
        File[] candidates = loc.getSessionsFolder().listFiles();
        if (candidates == null || candidates.length == 0) {
            return null;
        }
        Arrays.sort(candidates, Comparator.comparingLong(File::lastModified));
        for (File candidate : candidates) {
            if (this.sessionIsQueued(candidate)) {
                return candidate.getName();
            }
        }
        return null;
    }

    /** A session is queued when it is a directory holding {@code queue_info.json} but no {@code done.txt}. */
    private boolean sessionIsQueued(File f) {
        if (!f.isDirectory()) {
            return false;
        }
        boolean hasQueueInfo = new File(f, "queue_info.json").exists();
        return hasQueueInfo && !new File(f, "done.txt").exists();
    }

    /**
     * Resolves the identity a queued session should run as: an internal user with the given login,
     * or, for logins of the form {@code api:<keyId>}, the matching public API key.
     *
     * <p>Two fixes apply here:
     * <ul>
     *   <li>The NotLoggedInException error passed the cause message as a third format argument
     *       against a two-placeholder format, so the cause was silently dropped. It is now included.</li>
     *   <li>A {@code null} login with no matching user now yields the regular "cannot run" error
     *       instead of a NullPointerException.</li>
     * </ul>
     *
     * @param login stored login of the session owner
     * @param loc   ML task, used for error messages
     * @return the auth context to run under
     * @throws Exception (from ErrorContext.iaef) when no user or API key can be resolved
     */
    public DSSAuthCtx getQueuedSessionUser(String login, MLTaskLoc loc) throws Exception {
        UsersDAO.User user = this.usersService.getInternalUserOrNullUnsafe(login);
        try {
            if (user != null) {
                return DSSAuthCtx.forUserLogin(user);
            }
            if (login != null && login.startsWith("api:")) {
                PublicAPIKey apiKey = this.publicAPIKeysService.getKeyById(login.substring(4));
                if (apiKey != null) {
                    return DSSAuthCtx.forAPIKey(apiKey);
                }
            }
        }
        catch (NotLoggedInException e) {
            logger.warn((Object)String.format("Failed to retrieve user %s for ML Task %s", login, loc), (Throwable)e);
            throw ErrorContext.iaef((String)"Cannot run ML Task %s as user %s: %s", (Object)loc, (Object[])new Object[]{login, ExceptionUtils.getMessageWithCauses((Throwable)e)});
        }
        throw ErrorContext.iaef((String)"Cannot run ML Task %s as user %s", (Object)loc, (Object[])new Object[]{login});
    }

    /**
     * Checks the ML task's selected runtime code env and records findings as messages on {@code pts}:
     * <ul>
     *   <li>sentence-embedding compatibility;</li>
     *   <li>container-image build status;</li>
     *   <li>a missing explicit env;</li>
     *   <li>package-level incompatibilities.</li>
     * </ul>
     * Failures to fetch env information are logged and turned into a warning.
     *
     * @throws IllegalArgumentException for Deep Hub tasks, which use an internal code env
     */
    void checkVisualMlRuntime(AuthCtx user, PreTrainStatus pts, String projectKey, MLTask task) {
        if (task instanceof PredictionMLTask.DeepHubPredictionMLTask) {
            throw new IllegalArgumentException("Visual ML Env Selection check not supported for Deep Hub, because it uses an internal code env.");
        }
        final String selectCompatibleHint = "Please select a compatible code-env in 'Runtime environment'.";
        MLTaskCodeEnvCompatibilityComputer compatComputer = new MLTaskCodeEnvCompatibilityComputer(task);
        try {
            String envName = new CodeEnvSelector().selectForDoctor(projectKey, task.envSelection);
            this.checkSentenceEmbeddingEnvCompatibility(pts, task, envName);
            this.checkVisualMLContainerSelection(user, pts, envName, projectKey, task);
            boolean explicitEnvMissing = envName == null && task.envSelection.envMode == CodeEnvSelection.EnvMode.EXPLICIT_ENV;
            if (explicitEnvMissing) {
                pts.messages.add(InfoMessage.warning("No runtime code-env selected", selectCompatibleHint));
                return;
            }
            PythonCodeEnvPackagesUtils.CodeEnvVisualMLCompat envCompat;
            if (envName == null) {
                envCompat = PythonCodeEnvPackagesUtils.CodeEnvVisualMLCompat.builtinEnvCompatibility();
            } else {
                CodeEnvModel.DesignUIPythonEnv env = this.designNodeEnvsService.getPythonEnvForUI(CodeEnvModel.EnvLang.PYTHON, envName, false);
                envCompat = new PythonCodeEnvPackagesUtils.CodeEnvVisualMLCompat(PythonCodeEnvPackagesUtils.getEnvPackages(env));
            }
            Set<String> reasons = compatComputer.getIncompatibilityReasons_NT(envCompat);
            if (reasons == null || reasons.isEmpty()) {
                return;
            }
            ArrayList<String> details = new ArrayList<String>(reasons);
            details.add(selectCompatibleHint);
            String title = "Runtime code-env seems incompatible with " + compatComputer.getMLTaskDescriptionForCompatibility(envCompat);
            pts.messages.add(InfoMessage.warning(title, String.join((CharSequence)".\n", details)));
        }
        catch (DKUSecurityException | IOException e) {
            logger.warn((Object)"Failed to fetch information about mltask code-env", e);
            pts.messages.add(InfoMessage.warning("Failed to fetch information about selected env", selectCompatibleHint));
        }
    }

    /**
     * Warns (via {@code pts}) when the selected code env has not been built for the container config
     * that would run this ML task. Does nothing when no env is named or no container config is selected.
     *
     * @throws IOException          propagated from config/env lookups
     * @throws DKUSecurityException propagated from the container config selection
     */
    public void checkVisualMLContainerSelection(AuthCtx user, PreTrainStatus pts, String envName, String projectKey, MLTask task) throws IOException, DKUSecurityException {
        if (envName == null) {
            return;
        }
        ContainerExecRuntimeConfig selected = new ContainerExecConfigSelector().selectForML_autoTXN(user, projectKey, task.containerSelection, task.backendType);
        if (selected == null) {
            return;
        }
        List<ContainerExecRuntimeConfig> userCodeConfigs = ApplicationConfigurator.getGeneralSettings().containerSettings.listConfigsForWorkloadType(WorkloadType.USER_CODE);
        Set<String> builtFor = this.designNodeEnvsService.builtForContainerConfs(envName, CodeEnvModel.EnvLang.PYTHON, userCodeConfigs);
        if (builtFor.contains(selected.name)) {
            return;
        }
        String details = String.format("Build or ask your administrator to rebuild your selected code env (%s) for your containerized configuration (%s)", envName, selected.name);
        pts.messages.add(InfoMessage.warning("Runtime code-env not built for this containerized configuration.", details));
    }

    /**
     * Validates, for every feature using a code-env sentence-embedding model, the model selection
     * against the env's resource models:
     * <ul>
     *   <li>a blank model name is a fatal message;</li>
     *   <li>a model unknown to the env resources is skipped with an info log;</li>
     *   <li>a known model flagged not {@code compat} yields a warning.</li>
     * </ul>
     *
     * @throws IOException propagated from reading the env's resource models metadata
     */
    private void checkSentenceEmbeddingEnvCompatibility(PreTrainStatus pts, MLTask task, String envName) throws IOException {
        Map<String, String> featureToModel = task.getPreprocessingParams().codeEnvSentenceEmbeddedFeaturesAndModels();
        if (featureToModel.isEmpty()) {
            return;
        }
        CodeEnvModel.CodeEnvResourcesModelsMeta modelsMeta = this.designNodeEnvsService.getResourcesModelsMeta(envName);
        for (Map.Entry<String, String> pair : featureToModel.entrySet()) {
            String featureName = pair.getKey();
            String modelName = pair.getValue();
            if (StringUtils.isBlank(modelName)) {
                pts.messages.add(InfoMessage.fatal("Text embedding requires a model to be selected for feature: " + featureName));
            } else if (!modelsMeta.sentence_transformers.containsKey(modelName)) {
                logger.info("Skipping validation for model not in code env resources: " + modelName);
            } else if (!modelsMeta.sentence_transformers.get((Object)modelName).compat) {
                String rawType = modelsMeta.sentence_transformers.get((Object)modelName).type;
                String shownType = rawType == null ? "" : rawType;
                pts.messages.add(InfoMessage.warning("Text embedding model seems incompatible", String.format("Model type %s used for feature %s seems incompatible with text embedding.", shownType, featureName)));
            }
        }
    }

    /** Deletes the task's feature-selection folder, logging the action first. */
    public void forgetFeatureSelection(MLTaskLoc loc) throws IOException {
        String message = "Forgetting feature selection for " + String.valueOf(loc);
        logger.info(message);
        File selectionFolder = loc.getSelectionFolder();
        DKUFileUtils.deleteDirectory(selectionFolder);
    }

    /**
     * Deletes the folders of the given models and publishes a ModelVersionDeletedEvent for each one.
     * Per-model failures are logged and skipped. Afterwards, every analysis ML task that lost a model
     * gets its unused splits and empty preprocessing sets cleaned up; cleanup I/O failures are logged.
     *
     * @param fullModelIds string forms of the full model ids to delete
     */
    public void deleteModels(List<String> fullModelIds) {
        logger.info("Deleting some models");
        HashSet<MLTaskLoc> touchedTasks = new HashSet<MLTaskLoc>();
        for (String idString : fullModelIds) {
            try {
                FullModelId modelId = FullModelId.parse(idString);
                logger.info("Deleting " + idString);
                File folder = modelId.getModelFolder();
                FilesystemACLUtils.removeACLRestrictiveMask(folder);
                DKUFileUtils.deleteDirectory(folder);
                boolean isAnalysisModel = modelId.getType() == FullModelId.Type.ANALYSIS;
                if (isAnalysisModel) {
                    touchedTasks.add(modelId.getTaskLoc());
                }
                this.pubSubService.publish((DSSEvent)new ModelVersionDeletedEvent(modelId.toString()));
            }
            catch (Exception e) {
                logger.warn("failed to delete", e);
            }
        }
        for (MLTaskLoc touched : touchedTasks) {
            logger.info("Task loc " + String.valueOf(touched) + " was changed, cleaning up");
            try {
                this.keepOnlyUsedSplits(touched);
                this.keepOnlyUsedPreprocessingSets(touched);
            }
            catch (IOException e) {
                logger.warn("Failed to cleanup ML Task " + String.valueOf(touched), e);
            }
        }
    }

    /**
     * In every completed session of the task, force-deletes each subdirectory that contains no
     * {@code m<digits>} model folder.
     *
     * <p>NOTE(review): any subdirectory of the session folder qualifies, not only {@code pp<N>} ones.
     * Confirm sessions never hold other directories that must survive.
     */
    private void keepOnlyUsedPreprocessingSets(MLTaskLoc taskLoc) throws IOException {
        for (String sessionId : this.listCompletedSessions(taskLoc)) {
            logger.debug("Cleanup preprocessing sets of done session " + sessionId);
            File sessionFolder = taskLoc.getSessionFolder(sessionId);
            assert (sessionFolder.isDirectory());
            for (File preprocFolder : sessionFolder.listFiles()) {
                if (!preprocFolder.isDirectory()) {
                    continue;
                }
                boolean hasModel = false;
                for (File child : preprocFolder.listFiles()) {
                    if (child.isDirectory() && child.getName().matches("m\\d+")) {
                        hasModel = true;
                        break;
                    }
                }
                if (!hasModel) {
                    logger.info("No model left in preprocessing set " + preprocFolder.getName() + ", removing it");
                    DKUFileUtils.forceDelete(preprocFolder);
                }
            }
        }
    }

    /**
     * Lists every model of the task that may still reference a split: the models currently being
     * trained by the task's work thread (if any), followed by the completed models that are not being
     * trained. Each model appears once.
     *
     * <p>The list used to be seeded with listCompletedModelIds() and then had the completed ids
     * appended again. That duplicated every completed model and scanned the sessions folder twice.
     */
    private List<FullModelId> listAllModelsIdsOfTask(MLTaskLoc taskLoc) {
        ArrayList<FullModelId> modelIdsRemainingInThisTask = new ArrayList<FullModelId>();
        HashSet<String> modelsBeingTrained = new HashSet<String>();
        MLTaskWorkThread thread = this.getWorkingThread(taskLoc);
        if (thread != null) {
            for (FullModelId fmi : thread.modelsBeingTrained()) {
                modelsBeingTrained.add(fmi.toString());
                modelIdsRemainingInThisTask.add(fmi);
            }
        }
        for (FullModelId fmi : this.listCompletedModelIds(taskLoc)) {
            if (modelsBeingTrained.contains(fmi.toString())) continue;
            modelIdsRemainingInThisTask.add(fmi);
        }
        return modelIdsRemainingInThisTask;
    }

    /**
     * Returns the latest model of the task, delegating to the prediction or clustering service
     * according to the task type (read in its own read transaction).
     *
     * @throws IllegalArgumentException for any other task type (including null)
     * @throws IOException              propagated from the services
     */
    public FullModelId getLatestModelId(MLTaskLoc taskLoc) throws IOException {
        MLTask.MLTaskType kind;
        try (Transaction t = this.transactionService.beginRead();){
            kind = this.crudService.getMLTask(taskLoc).taskType;
        }
        if (kind == MLTask.MLTaskType.PREDICTION) {
            PredictionMLTask predictionTask;
            try (Transaction t = this.transactionService.beginRead();){
                predictionTask = this.crudService.getPMLTask(taskLoc);
            }
            return this.predictionService.getLatestModelId(predictionTask, this.predictionService.listTaskModelIds(taskLoc));
        }
        if (kind == MLTask.MLTaskType.CLUSTERING) {
            return this.clusteringService.getLatestModelId(this.clusteringService.listTaskModelIds(taskLoc));
        }
        throw new IllegalArgumentException("Unsupported task type: " + String.valueOf((Object)kind));
    }

    /**
     * Removes the task's splits that are neither referenced by a remaining model (completed or being
     * trained) nor currently being computed.
     *
     * <p>The "being computed" set is now read once. Previously it was fetched separately for logging
     * and for use, so the logged set could differ from the set actually kept.
     *
     * @throws IOException propagated from reading split descriptors or deleting splits
     */
    public void keepOnlyUsedSplits(MLTaskLoc taskLoc) throws IOException {
        logger.info((Object)("Keeping only used splits for task loc " + String.valueOf(taskLoc)));
        HashSet<String> requiredSplitIds = new HashSet<String>();
        for (FullModelId fmi : this.listAllModelsIdsOfTask(taskLoc)) {
            SplitDesc splitDesc = ResultsReaderBase.readSplitDesc(fmi);
            if (splitDesc == null) continue;
            requiredSplitIds.add(splitDesc.instanceId);
        }
        logger.info((Object)("Keeping only the following split descs: " + JSON.json(requiredSplitIds)));
        // Single snapshot: what is logged is exactly what is protected from deletion.
        HashSet<String> splitsBeingComputed = new HashSet<String>(SplitGenerator.getSplitsBeingComputed(taskLoc));
        logger.info((Object)("Plus the ones being built: " + String.valueOf(splitsBeingComputed)));
        requiredSplitIds.addAll(splitsBeingComputed);
        taskLoc.keepOnlySplits(requiredSplitIds);
    }

    /** Force-deletes every entry of the task's splits folder, if that folder is a directory. */
    public void deleteAllSplits(MLTaskLoc taskLoc) throws IOException {
        logger.info("Deleting all splits for task loc " + String.valueOf(taskLoc));
        if (!taskLoc.getSplitsFolder().isDirectory()) {
            return;
        }
        for (File entry : taskLoc.getSplitsFolder().listFiles()) {
            DKUFileUtils.forceDelete(entry);
        }
    }

    /**
     * Builds a future payload targeting a model of the given ML task. The task name is looked up
     * best-effort through Spring beans in a read transaction; on any failure it is logged and left null.
     *
     * @param loc         ML task
     * @param displayName display name of the future
     * @param part        used both as the payload action and as the target part
     * @param modelType   model type passed to the target builder
     */
    public static FuturePayload buildBaseFuturePayload(MLTaskLoc loc, String displayName, String part, String modelType) {
        FuturePayload payload = new FuturePayload();
        payload.action = part;
        String resolvedTaskName = null;
        try (Transaction t = ((TransactionService)SpringUtils.getBean(TransactionService.class)).beginRead();){
            AnalysisCRUDService crud = (AnalysisCRUDService)SpringUtils.getBean(AnalysisCRUDService.class);
            resolvedTaskName = crud.getMLTask(loc).name;
        }
        catch (Throwable e) {
            logger.info((Object)"Failed to get ML task", e);
        }
        payload.targets.add(DSSFuturePayloadUtils.forMLTaskModel(loc, modelType, resolvedTaskName).withPart(part));
        payload.displayName = displayName;
        return payload;
    }

    /** Registers the training log appender with DKUtils and the local, Docker and Kubernetes Python kernel loggers. */
    public void startLogAppender() {
        Appender trainAppender = (Appender)this.appender;
        DKUtils.startLogAppender(trainAppender);
        LocalSimplePythonKernel.loggerNotAdditive.addAppender(trainAppender);
        DockerSimplePythonKernel.loggerNotAdditive.addAppender(trainAppender);
        KubernetesSimplePythonKernel.loggerNotAdditive.addAppender(trainAppender);
    }

    /** Marks the task as having a guess in progress. */
    public synchronized void addGuessing(MLTaskLoc loc) {
        guessing.add(loc);
    }

    /** Clears the "guess in progress" mark of the task. */
    public synchronized void removeGuessing(MLTaskLoc loc) {
        guessing.remove(loc);
    }

    /** Tells whether a guess is in progress for the task. */
    public synchronized boolean isGuessing(MLTaskLoc loc) {
        final boolean guessingNow = guessing.contains(loc);
        return guessingNow;
    }

    /** Tells whether a work thread is registered for the task. */
    public synchronized boolean isTraining(MLTaskLoc loc) {
        final boolean trainingNow = working.containsKey(loc);
        return trainingNow;
    }

    /**
     * Rejects a new operation on the task while it is busy.
     *
     * @param forQueueing when true, an ongoing training is tolerated (the new work will be queued)
     * @throws IllegalArgumentException "Already guessing" or "Already training"
     */
    public synchronized void failIfWorking(MLTaskLoc loc, boolean forQueueing) {
        if (guessing.contains(loc)) {
            throw new IllegalArgumentException("Already guessing");
        }
        boolean trainingConflict = working.containsKey(loc) && !forQueueing;
        if (trainingConflict) {
            throw new IllegalArgumentException("Already training");
        }
    }

    /** Registers (or replaces) the work thread training the task. */
    public synchronized void setWorking(MLTaskLoc loc, MLTaskWorkThread thread) {
        working.put(loc, thread);
    }

    /** Unregisters the task's work thread, if any. */
    public synchronized void setNotWorking(MLTaskLoc loc) {
        working.remove(loc);
    }

    /** Returns the work thread registered for the task, or {@code null} when none is. */
    public synchronized MLTaskWorkThread getWorkingThread(MLTaskLoc loc) {
        final MLTaskWorkThread registered = working.get(loc);
        return registered;
    }

    /**
     * Allocates the next {@code s<N>} session id of the task and creates its folder. This runs while
     * holding a NamedLock keyed by the task, so concurrent callers holding the same lock get distinct ids.
     *
     * @return the new session id
     * @throws IOException          if the session folder cannot be created (e.g. it already exists)
     * @throws InterruptedException if interrupted while waiting for the lock
     */
    public String getNextSessionIdAndCreateSessionFolder(MLTaskLoc loc) throws InterruptedException, IOException {
        final String lockName = "analysis.ml.create-session." + loc.toString();
        try (AutoCloseableLock lock = NamedLock.acquireInterruptibly((String)lockName);){
            String sessionId = DKUtils.nextFileInSequence(loc.getSessionsFolder(), "s").getName();
            File folder = MLPaths.sessionFolder(loc, sessionId);
            if (!folder.mkdirs()) {
                throw new IOException("Could not create session folder for session '" + sessionId + "', for task '" + String.valueOf(loc) + "', may already exist.");
            }
            return sessionId;
        }
    }

    /** Records the task's queue as RUNNING ({@code running == true}) or PAUSED. */
    public synchronized void setQueueState(MLTaskLoc loc, boolean running) {
        MLTaskQueueStatus status;
        if (running) {
            status = MLTaskQueueStatus.RUNNING;
        } else {
            status = MLTaskQueueStatus.PAUSED;
        }
        taskQueues.put(loc, status);
    }

    /**
     * Returns a snapshot of the per-task queue statuses.
     *
     * <p>The internal map is only safe to touch under this object's monitor. Returning it directly let
     * callers iterate or mutate it unsynchronized, risking ConcurrentModificationException. A copy is
     * returned instead; use setQueueState() to change a status.
     */
    public synchronized Map<MLTaskLoc, MLTaskQueueStatus> getQueues() {
        return new HashMap<MLTaskLoc, MLTaskQueueStatus>(this.taskQueues);
    }

    /** Returns the recorded queue status of the task, or UNKNOWN when none was recorded. */
    public synchronized MLTaskQueueStatus getQueueStatus(MLTaskLoc loc) {
        final MLTaskQueueStatus fallback = MLTaskQueueStatus.UNKNOWN;
        return taskQueues.getOrDefault(loc, fallback);
    }

    /**
     * Aborts the task's running training, if any. The steps are:
     * <ol>
     *   <li>optionally pause the queue;</li>
     *   <li>abort the thread's future;</li>
     *   <li>mark each non-partitioned-base model being trained as ABORTED, unless it already reached DONE.</li>
     * </ol>
     * All failures are logged and swallowed (best effort).
     */
    public void abort_NT(MLTaskLoc loc, boolean pauseQueue) throws Exception {
        try {
            MLTaskWorkThread running = this.getWorkingThread(loc);
            if (running == null) {
                return;
            }
            if (pauseQueue) {
                this.setQueueState(loc, false);
            }
            this.futureService.abort(running.jobId);
            for (FullModelId modelId : running.modelsBeingTrained()) {
                if (modelId.isPartitionedBaseModel()) {
                    continue;
                }
                final File infoFile = modelId.getModelInfoFile();
                ModelStateHelper.updateModelTrainInfoAtomically(infoFile, (Runnable)() -> {
                    try {
                        ModelTrainInfo info = (ModelTrainInfo)JSON.parseFile(infoFile, ModelTrainInfo.class);
                        if (info.state != ModelTrainInfo.ModelTrainState.DONE) {
                            info.state = ModelTrainInfo.ModelTrainState.ABORTED;
                            JSON.prettyToFile((Object)info, infoFile);
                        }
                    }
                    catch (IOException ioe) {
                        logger.warn("Failed to record aborted state", ioe);
                    }
                });
            }
        }
        catch (Exception e) {
            logger.warn("Failed to abort", e);
        }
    }

    /**
     * Aborts only the listed models of the task (best effort; all failures are logged and swallowed).
     *
     * <p>Now consistent with the full abort_NT:
     * <ul>
     *   <li>A model whose train info already reports DONE is no longer overwritten with ABORTED.
     *       Previously a model that finished just before the request was hidden as aborted.</li>
     *   <li>The queue is paused before the work thread is told to abort, as in the full abort.</li>
     * </ul>
     *
     * @param loc             ML task
     * @param fullModelIdList models to abort
     * @param pauseQueue      whether to pause the task's queue (only when a work thread is registered)
     */
    public void abort_NT(MLTaskLoc loc, List<FullModelId> fullModelIdList, boolean pauseQueue) throws Exception {
        try {
            MLTaskWorkThread thread = this.getWorkingThread(loc);
            if (thread != null) {
                if (pauseQueue) {
                    this.setQueueState(loc, false);
                }
                thread.abort(fullModelIdList);
            }
            for (FullModelId fmi : fullModelIdList) {
                if (fmi.isPartitionedBaseModel()) continue;
                final File mtiFile = fmi.getModelInfoFile();
                ModelStateHelper.updateModelTrainInfoAtomically(mtiFile, new Runnable(){

                    @Override
                    public void run() {
                        try {
                            ModelTrainInfo mti = (ModelTrainInfo)JSON.parseFile((File)mtiFile, ModelTrainInfo.class);
                            // Never downgrade a finished model to ABORTED.
                            if (!ModelTrainInfo.ModelTrainState.DONE.equals((Object)mti.state)) {
                                mti.state = ModelTrainInfo.ModelTrainState.ABORTED;
                                JSON.prettyToFile((Object)mti, (File)mtiFile);
                            }
                        }
                        catch (IOException e) {
                            logger.warn((Object)"Failed to record aborted state", (Throwable)e);
                        }
                    }
                });
            }
        }
        catch (Exception e) {
            logger.warn((Object)"Failed to partial abort", (Throwable)e);
        }
    }

    /** Writes a training diagnosis for the model to {@code outputStream}, optionally including training data. */
    public void generateTrainDiagnosis(AuthCtx authCtx, FullModelId fmi, boolean includeTrainingData, OutputStream outputStream) throws IOException, InterruptedException {
        new TrainDiagnosisGenerator(authCtx).generateDiagnosis(fmi, includeTrainingData, outputStream);
    }

    /**
     * Deletes an ML task: first its on-disk data folder, then its definition via the CRUD service.
     *
     * @param loc location of the ML task to delete
     * @throws IllegalArgumentException if the task is currently guessing or training
     */
    public void deleteMLTask(MLTaskLoc loc) throws Exception {
        boolean busy = this.isGuessing(loc) || this.isTraining(loc);
        if (busy) {
            throw new IllegalArgumentException("Can't delete while guessing or training");
        }
        File dataFolder = loc.getDataFolder();
        DKUFileUtils.deleteDirectory(dataFolder);
        this.crudService.deleteMLTask(loc.analysisProjectKey, loc.analysisId, loc.mlTaskId);
    }

    /**
     * Copies per-feature preprocessing settings from {@code taskFrom} into {@code taskTo}.
     *
     * <p>Only features already present in {@code taskTo} are considered. For each one that also
     * exists in {@code taskFrom}, the source settings are deep-copied. If
     * {@link #skipFeatureHandlingCopy} does not reject them, their role is adjusted with
     * {@link #getPostCopyRole} and they replace the destination settings. Features absent from
     * {@code taskFrom}, or rejected, keep their current settings.
     *
     * @param taskFrom task whose feature handling is copied
     * @param taskTo   task that receives the feature handling (modified in place)
     */
    public void copyFeaturesHandling(MLTask taskFrom, MLTask taskTo) {
        Map<String, FeaturePreprocessingParams> featureParamsFrom = taskFrom.getPreprocessingParams().per_feature;
        Map<String, FeaturePreprocessingParams> featureParamsTo = taskTo.getPreprocessingParams().per_feature;
        for (Map.Entry<String, FeaturePreprocessingParams> entry : featureParamsTo.entrySet()) {
            FeaturePreprocessingParams sourceParams = featureParamsFrom.get(entry.getKey());
            // Check for a missing source feature BEFORE deep-copying, rather than deep-copying null.
            if (sourceParams == null) {
                continue;
            }
            FeaturePreprocessingParams.Role roleTo = entry.getValue().role;
            FeaturePreprocessingParams featureParamFrom = (FeaturePreprocessingParams)JSON.deepCopy((Object)sourceParams);
            if (this.skipFeatureHandlingCopy(taskTo.backendType, taskTo.taskType, featureParamFrom, roleTo)) {
                continue;
            }
            featureParamFrom.role = this.getPostCopyRole(taskFrom.taskType, taskTo.taskType, featureParamFrom.role, roleTo);
            // Replace through the entry: no map mutation while iterating its entry set.
            entry.setValue(featureParamFrom);
        }
    }

    /**
     * Persists new user metadata for a model.
     *
     * <p>The interactive model is invalidated when {@link #shouldInvalidateModel} reports a
     * relevant change between the previous and the new metadata.
     *
     * @param fmi     model whose metadata is saved
     * @param newMeta metadata to store
     */
    public void saveModelUserMeta(FullModelId fmi, ModelUserMeta newMeta) throws IOException {
        // Capture the previous metadata before it is overwritten.
        ModelUserMeta previousMeta = fmi.getUserMeta();
        fmi.saveUserMeta(newMeta);
        if (!this.shouldInvalidateModel(previousMeta, newMeta)) {
            return;
        }
        this.interactiveModelService.invalidateModel(fmi);
    }

    /**
     * Tells whether a metadata change requires invalidating the interactive model.
     * Only a change of the active classifier threshold counts.
     */
    private boolean shouldInvalidateModel(ModelUserMeta oldMeta, ModelUserMeta newMeta) {
        // NOTE(review): if activeClassifierThreshold is a boxed Double rather than a double,
        // this is an identity comparison — confirm the field's declared type.
        boolean thresholdChanged = !(oldMeta.activeClassifierThreshold == newMeta.activeClassifierThreshold);
        return thresholdChanged;
    }

    /**
     * Computes the role a feature gets after its handling is copied between tasks.
     *
     * <ul>
     *   <li>PROFILING in the source, INPUT or REJECT in a PREDICTION destination becomes REJECT.</li>
     *   <li>REJECT in a PREDICTION source, PROFILING in the destination becomes PROFILING.</li>
     *   <li>Otherwise the source role is kept.</li>
     * </ul>
     *
     * @param taskTypeFrom type of the source task
     * @param taskTypeTo   type of the destination task
     * @param roleFrom     role of the feature in the source task
     * @param roleTo       role of the feature in the destination task
     * @return the role to assign to the copied settings
     */
    public FeaturePreprocessingParams.Role getPostCopyRole(MLTask.MLTaskType taskTypeFrom, MLTask.MLTaskType taskTypeTo, FeaturePreprocessingParams.Role roleFrom, FeaturePreprocessingParams.Role roleTo) {
        boolean destinationIsInputOrReject = roleTo == FeaturePreprocessingParams.Role.INPUT
                || roleTo == FeaturePreprocessingParams.Role.REJECT;
        boolean profilingIntoPrediction = roleFrom == FeaturePreprocessingParams.Role.PROFILING
                && taskTypeTo == MLTask.MLTaskType.PREDICTION;
        if (destinationIsInputOrReject && profilingIntoPrediction) {
            return FeaturePreprocessingParams.Role.REJECT;
        }
        boolean rejectFromPrediction = roleFrom == FeaturePreprocessingParams.Role.REJECT
                && taskTypeFrom == MLTask.MLTaskType.PREDICTION;
        if (rejectFromPrediction && roleTo == FeaturePreprocessingParams.Role.PROFILING) {
            return FeaturePreprocessingParams.Role.PROFILING;
        }
        return roleFrom;
    }

    /**
     * Tells whether a feature's preprocessing settings must NOT be copied into a destination task.
     *
     * <p>The copy is always skipped when the feature has a TARGET, WEIGHT, TIME or
     * TIMESERIES_IDENTIFIER role on either side. Otherwise the decision depends on the feature
     * type, its handling method and the destination backend. Unhandled backend types and
     * feature types are logged and skipped.
     *
     * @param backendTypeTo backend of the destination task
     * @param taskTypeTo    type of the destination task
     * @param paramsFrom    settings of the feature in the source task
     * @param roleTo        role of the feature in the destination task
     * @return {@code true} if the settings must not be copied
     */
    public boolean skipFeatureHandlingCopy(MLTask.BackendType backendTypeTo, MLTask.MLTaskType taskTypeTo, FeaturePreprocessingParams paramsFrom, FeaturePreprocessingParams.Role roleTo) {
        FeaturePreprocessingParams.Role roleFrom = paramsFrom.role;
        if (roleFrom == FeaturePreprocessingParams.Role.TARGET || roleTo == FeaturePreprocessingParams.Role.TARGET
                || roleFrom == FeaturePreprocessingParams.Role.WEIGHT || roleTo == FeaturePreprocessingParams.Role.WEIGHT
                || roleFrom == FeaturePreprocessingParams.Role.TIME || roleTo == FeaturePreprocessingParams.Role.TIME
                || roleFrom == FeaturePreprocessingParams.Role.TIMESERIES_IDENTIFIER || roleTo == FeaturePreprocessingParams.Role.TIMESERIES_IDENTIFIER) {
            return true;
        }
        switch (paramsFrom.type) {
            case CATEGORY: {
                CatFeaturePreprocessingParams.CategoryHandlingMethod handlingMethod = ((CatFeaturePreprocessingParams)paramsFrom).category_handling;
                switch (backendTypeTo) {
                    case H2O: {
                        return handlingMethod != CatFeaturePreprocessingParams.CategoryHandlingMethod.DUMMIFY && handlingMethod != CatFeaturePreprocessingParams.CategoryHandlingMethod.NONE;
                    }
                    case MLLIB: {
                        return handlingMethod != CatFeaturePreprocessingParams.CategoryHandlingMethod.DUMMIFY && handlingMethod != CatFeaturePreprocessingParams.CategoryHandlingMethod.FLAG_PRESENCE;
                    }
                    case PY_MEMORY: 
                    case KERAS: {
                        // Parenthesized to make the existing precedence (&& before ||) explicit.
                        return handlingMethod == CatFeaturePreprocessingParams.CategoryHandlingMethod.NONE
                                || (handlingMethod == CatFeaturePreprocessingParams.CategoryHandlingMethod.IMPACT && taskTypeTo == MLTask.MLTaskType.CLUSTERING);
                    }
                }
                // Report the backend that was not handled (previously this logged the feature type).
                logger.warn((Object)("Unsupported backend type " + String.valueOf((Object)backendTypeTo)));
                return true;
            }
            case TEXT: {
                return backendTypeTo.isSparkBased() && ((TextFeaturePreprocessingParams)paramsFrom).text_handling != TextFeaturePreprocessingParams.TextHandlingMethod.TOKENIZE_COUNTS;
            }
            case IMAGE: {
                return backendTypeTo != MLTask.BackendType.KERAS;
            }
            case NUMERIC: {
                return backendTypeTo.isSparkBased() && ((NumFeaturePreprocessingParams)paramsFrom).numerical_handling != NumFeaturePreprocessingParams.NumericalHandlingMethod.REGULAR;
            }
            case VECTOR: {
                return false;
            }
        }
        logger.warn((Object)("Unsupported feature type " + String.valueOf((Object)paramsFrom.type)));
        return true;
    }

    /*
     * Class initialization: calls DKUMLUtils.loadClasses(), then assigns the class logger
     * ("dku.analysis.prediction"). The logger field is declared outside this view and is
     * presumably a static (final) field assigned only here — confirm before moving it.
     */
    static {
        DKUMLUtils.loadClasses();
        logger = Logger.getLogger((String)"dku.analysis.prediction");
    }

    /**
     * An entry of an ML task's session queue: an identifier plus its metadata.
     * NOTE(review): may be (de)serialized — keep field names and order stable.
     */
    public static class QueuedSession {
        // Identifier of the queued session.
        String id;
        // User-facing information attached to the queued session.
        QueuedSessionMetadata metadata;
    }

    /**
     * Metadata describing a queued session.
     * NOTE(review): may be (de)serialized — keep field names and order stable.
     */
    public static class QueuedSessionMetadata {
        // Session name given by the user.
        public String userSessionName;
        // Session description given by the user.
        public String userSessionDescription;
        // Boxed Boolean, so it can be null (unset); presumably requests regenerating the split — confirm.
        public Boolean forceSplitRefresh;
        // Presumably the login of the user who queued the session — confirm.
        public String sessionOwner;
    }

    /**
     * Status of an ML task's session queue. Constant order defines ordinals; do not reorder.
     */
    public static enum MLTaskQueueStatus {
        RUNNING,
        PAUSED,
        // Presumably used when the state cannot be determined — confirm with callers.
        UNKNOWN;

    }

    /**
     * Base class for background work threads attached to an ML task.
     *
     * <p>Subclasses provide the future payload, initialization, the list of models being trained
     * and the process ids. Partial abort is unsupported unless {@link #abort} is overridden.
     * After a run, {@link #postRunCleanup()} asks {@code MLBaseService} to start the next queued
     * session.
     */
    public static abstract class MLTaskWorkThread
    extends FutureThread<Void> {
        private final FuturePayload futurePayload;
        /** Models handled by the current worker; the value is inherited by threads it spawns. */
        public static final InheritableThreadLocal<MLTaskContext> mlTaskContext = new InheritableThreadLocal<MLTaskContext>();
        protected final MLTaskLoc loc;
        protected final AnalysisCoreParams cp;

        /**
         * Note: {@link #buildFuturePayload} is called from this constructor, i.e. before any
         * subclass field is initialized. Implementations must rely only on their arguments.
         */
        public MLTaskWorkThread(DSSAuthCtx authCtx, MLTaskLoc loc, AnalysisCoreParams cp) {
            super(authCtx);
            this.loc = loc;
            this.cp = cp;
            this.futurePayload = this.buildFuturePayload(loc, cp);
        }

        /** Builds the payload describing this thread; called once, during construction. */
        public abstract FuturePayload buildFuturePayload(MLTaskLoc var1, AnalysisCoreParams var2);

        public abstract void init() throws Exception;

        /** Models currently being trained by this thread. */
        public abstract List<FullModelId> modelsBeingTrained();

        /** Container usage metrics per model; empty by default. */
        public Map<FullModelId, ContainerUsageMetrics> getContainerUsageMetricsPerModel() {
            return new HashMap<FullModelId, ContainerUsageMetrics>();
        }

        /**
         * Aborts only the given models. Unsupported by default.
         *
         * @throws NotImplementedException always, unless overridden
         */
        public void abort(List<FullModelId> fullModelIds) throws IOException {
            throw new NotImplementedException("Partial abort not implemented for " + this.getClass().getSimpleName());
        }

        /**
         * Tries to start the next queued session of this task. If that fails, the failure is
         * logged and the task's queue state is set to {@code false}.
         */
        public void postRunCleanup() {
            try {
                ((MLBaseService)SpringUtils.getBean(MLBaseService.class)).checkAndRunQueue(this.cp, this.loc);
            }
            catch (Exception e) {
                logger.info((Object)"Failed to resume queue and start next session", (Throwable)e);
                ((MLBaseService)SpringUtils.getBean(MLBaseService.class)).setQueueState(this.loc, false);
            }
        }

        public double getDangerosity() {
            return 0.0;
        }

        /** Process ids of the processes run by this thread. */
        public abstract List<Integer> getPids();

        /** Returns the payload, refreshing its "pids" and "mlTaskId" extras on every call. */
        public FuturePayload getPayload() {
            this.futurePayload.withExtra("pids", this.getPids());
            this.futurePayload.withExtra("mlTaskId", (Object)this.loc.mlTaskId);
            return this.futurePayload;
        }

        /** Models handled by a work thread, exposed through {@link #mlTaskContext}. */
        public static class MLTaskContext {
            public List<FullModelId> fullModelIds;

            public MLTaskContext(List<FullModelId> fullModelIds) {
                this.fullModelIds = fullModelIds;
            }

            /** The "train.log" file in the preprocessing folder of each model, in list order. */
            public List<File> getSessionLogFiles() {
                ArrayList<File> ppLogFiles = new ArrayList<File>(this.fullModelIds.size());
                for (FullModelId fmi : this.fullModelIds) {
                    ppLogFiles.add(new File(fmi.getPreprocessingFolder(), "train.log"));
                }
                return ppLogFiles;
            }
        }
    }
}

