/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.shared;

import com.dataiku.common.server.APIError;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.ml.shared.EvaluationLabelsHelper;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.clustering.HeatMap;
import com.dataiku.dip.analysis.model.core.ModelUserMeta;
import com.dataiku.dip.analysis.model.core.PostTrainModelingParams;
import com.dataiku.dip.analysis.model.core.PreTrainModelingParams;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.ActualModelParameters;
import com.dataiku.dip.analysis.model.prediction.BinaryClassificationModelPerf;
import com.dataiku.dip.analysis.model.prediction.DeepHubPreTrainModelingParams;
import com.dataiku.dip.analysis.model.prediction.DeepHubPredictionModelPerf;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.pivot.frontend.color.Color;
import com.dataiku.dip.pivot.frontend.color.DiscretePalette;
import com.dataiku.dip.pivot.frontend.color.PaletteFactory;
import com.dataiku.dip.utils.AutoCloseableLock;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.NamedLock;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;

public class ModelStateHelper {
    private static final DKULogger logger = DKULogger.getLogger("dku.analysis.ml");

    /**
     * Runs {@code runnable} while holding the named lock {@code "ml.traininfo." + <absolute path of mtiFile>}.
     *
     * <p>The lock is released, and the release is logged, whether {@code runnable} returns normally or throws.
     * Any unchecked exception thrown by {@code runnable} propagates to the caller. All updates to the same file
     * that go through this method use the same lock name. This assumes {@link NamedLock} provides mutual
     * exclusion per name, which is not visible from this file.
     *
     * @param mtiFile  the train info file being updated; only its absolute path is used, to build the lock name
     * @param runnable the read-modify-write work to perform while the lock is held
     */
    public static void updateModelTrainInfoAtomically(File mtiFile, Runnable runnable) {
        String mtiFilePath = mtiFile.getAbsolutePath();
        String lockName = "ml.traininfo." + mtiFilePath;
        logger.info("Locking model train info file " + mtiFilePath);
        // The resource variable is never read: it only scopes the lock to the try block.
        try (AutoCloseableLock lock = NamedLock.acquire(lockName)) {
            runnable.run();
        } finally {
            logger.info("Unlocked model train info file " + mtiFilePath);
        }
    }

    /**
     * Marks every model of {@code pps} that is still PENDING or RUNNING as FAILED.
     *
     * @param pps the preprocessing set whose modeling sets are updated
     * @param t   failure cause recorded in each updated train info; may be {@code null}
     */
    public static void markAllNotDoneAsFailed(WorkSet.PreprocessingSet pps, Throwable t) {
        markAllNotFinalAsState(pps, ModelTrainInfo.ModelTrainState.FAILED, t);
    }

    /**
     * Moves every model of {@code pps} whose train info state is PENDING or RUNNING to {@code state}.
     *
     * <p>This method is best effort. A missing, empty or unreadable train info file is logged and skipped,
     * so the remaining models are still processed.
     *
     * @param pps   the preprocessing set whose modeling sets are updated
     * @param state the state to set on models that are not yet in a final state
     * @param t     if non-null, recorded as the failure of every updated model
     */
    public static void markAllNotFinalAsState(WorkSet.PreprocessingSet pps, ModelTrainInfo.ModelTrainState state, Throwable t) {
        for (WorkSet.ModelingSet ms : pps.modelingSets) {
            rewriteModelTrainInfo(ms.fullId.getModelInfoFile(), mti -> {
                boolean notFinal = mti.state == ModelTrainInfo.ModelTrainState.PENDING
                        || mti.state == ModelTrainInfo.ModelTrainState.RUNNING;
                if (notFinal) {
                    mti.state = state;
                    if (t != null) {
                        mti.failure = new APIError(t, true);
                    }
                }
            });
        }
    }

    /**
     * Unconditionally sets the train info state of model {@code fmi} to {@code state}.
     *
     * <p>This method is best effort. A missing, empty or unreadable train info file is logged, not thrown.
     *
     * @param fmi   the model whose train info file is updated
     * @param state the new state
     */
    public static void setModelState(FullModelId fmi, ModelTrainInfo.ModelTrainState state) {
        rewriteModelTrainInfo(fmi.getModelInfoFile(), mti -> {
            mti.state = state;
        });
    }

    /**
     * Performs a locked read-modify-write of a train info file.
     *
     * <p>Logs a warning and returns if the file does not exist. Under the lock, it parses the file, skips it
     * (with a warning) if parsing yields {@code null}, applies {@code mutator}, and writes the result back.
     * An {@link IOException} while reading or writing is logged rather than thrown.
     *
     * @param mtiFile the train info file to update
     * @param mutator in-place modification applied to the parsed train info
     */
    private static void rewriteModelTrainInfo(File mtiFile, Consumer<ModelTrainInfo> mutator) {
        if (!mtiFile.exists()) {
            logger.warn("MTI file " + mtiFile + " does not exist");
            return;
        }
        updateModelTrainInfoAtomically(mtiFile, () -> {
            try {
                ModelTrainInfo mti = (ModelTrainInfo) JSON.parseFile(mtiFile, ModelTrainInfo.class);
                if (mti == null) {
                    // For example, an empty file. Throwing an NPE here would abort the caller's loop over models.
                    logger.warn("MTI file " + mtiFile + " could not be parsed, not updating status");
                    return;
                }
                mutator.accept(mti);
                JSON.prettyToFile(mti, mtiFile);
            } catch (IOException e) {
                logger.warn("Failed to update status", e);
            }
        });
    }

    /**
     * Finalizes train info and user meta for each successfully trained prediction model of {@code pps}.
     *
     * <p>Modeling sets whose {@code train_info.json} state is not DONE are skipped. For each remaining set:
     * <ul>
     *   <li>For prediction modeling params, it fills {@code postSearchDescription}. For binary classification, it
     *       also copies {@code usedThreshold} from {@code perf.json} into {@code activeClassifierThreshold}.</li>
     *   <li>For deep hub object detection, it copies the optimal confidence score threshold, if the perf is
     *       available, into {@code activeClassifierThreshold}.</li>
     *   <li>It sets the train-time label from the train info end time.</li>
     * </ul>
     * It then writes {@code train_info.json} and {@code user_meta.json} into the run folder.
     *
     * @param type the prediction type of the task
     * @param pps  the preprocessing set whose modeling sets are finalized
     * @throws IOException if a model file cannot be read or written
     */
    public static void updatePredictionTrainInfoAndUserMeta(PredictionMLTask.PredictionType type, WorkSet.PreprocessingSet pps) throws IOException {
        for (WorkSet.ModelingSet ms : pps.modelingSets) {
            ModelTrainInfo mti = ms.fullId.parseModelFile("train_info.json", ModelTrainInfo.class);
            if (mti.state != ModelTrainInfo.ModelTrainState.DONE) {
                continue;
            }
            PreTrainModelingParams preTrain = ms.fullId.parseModelFile("rmodeling_params.json", PreTrainModelingParams.class);
            ActualModelParameters amp = ms.fullId.parseModelFile("actual_params.json", ActualModelParameters.class);
            logger.info("preTrain=" + JSON.prettyLog(preTrain));
            logger.info("amp=" + JSON.prettyLog(amp));
            assert amp != null;
            assert amp.resolved != null;
            assert ms.userMeta != null;
            PostTrainModelingParams resolved = amp.resolved;
            if (preTrain instanceof PreTrainPredictionModelingParams) {
                mti.postSearchDescription = resolved.generatePostSearchDescription(mti.preSearchDescription, preTrain);
                if (type == PredictionMLTask.PredictionType.BINARY_CLASSIFICATION) {
                    BinaryClassificationModelPerf perf = ms.fullId.parseModelFile("perf.json", BinaryClassificationModelPerf.class);
                    ms.userMeta.activeClassifierThreshold = perf.usedThreshold;
                }
            } else if (preTrain instanceof DeepHubPreTrainModelingParams
                    && type == PredictionMLTask.PredictionType.DEEP_HUB_IMAGE_OBJECT_DETECTION) {
                // For this prediction type, the perf is expected to be the object-detection variant.
                ms.fullId.getDeepHubPredictionPerf().ifPresent(p -> {
                    ms.userMeta.activeClassifierThreshold =
                            ((DeepHubPredictionModelPerf.DeepHubObjectDetectionPredictionModelPerf) p).optimalConfidenceScoreThreshold;
                });
            }
            ms.userMeta.labels = EvaluationLabelsHelper.setTrainTime(ms.userMeta.labels, mti.endTime);
            JSON.prettyToFile(mti, new File(ms.run_folder, "train_info.json"));
            JSON.prettyToFile(ms.userMeta, new File(ms.run_folder, "user_meta.json"));
        }
    }

    /**
     * Sets the train-time label of model {@code fmi} from {@code mti.endTime} and saves its user meta.
     *
     * @param fmi the model whose user meta is updated
     * @param mti the train info providing the end time
     * @throws IOException if the user meta cannot be read or saved
     */
    public static void addModelDateLabel(FullModelId fmi, ModelTrainInfo mti) throws IOException {
        ModelUserMeta userMeta = fmi.getUserMeta();
        userMeta.labels = EvaluationLabelsHelper.setTrainTime(userMeta.labels, mti.endTime);
        fmi.saveUserMeta(userMeta);
    }

    /**
     * Finalizes the user meta of each successfully trained clustering model of {@code pps}.
     *
     * <p>Modeling sets whose {@code train_info.json} state is not DONE are skipped. For each remaining set, it:
     * <ul>
     *   <li>appends a "Trained in ... on ... records" line to the description,</li>
     *   <li>sets the train-time label, and</li>
     *   <li>assigns a palette color to every cluster from {@code heatmap.json}.</li>
     * </ul>
     * It then writes {@code user_meta.json} into the run folder. {@code train_info.json} is not rewritten.
     *
     * @param pps the preprocessing set whose modeling sets are finalized
     * @throws IOException if a model, session or split file cannot be read, or the user meta cannot be written
     */
    public static void updateClusteringTrainInfoAndUserMeta(WorkSet.PreprocessingSet pps) throws IOException {
        for (WorkSet.ModelingSet ms : pps.modelingSets) {
            ModelTrainInfo mti = ms.fullId.parseModelFile("train_info.json", ModelTrainInfo.class);
            if (mti.state != ModelTrainInfo.ModelTrainState.DONE) {
                continue;
            }
            // Only read to check the resolved params are present (the checks are active only with -ea).
            ActualModelParameters amp = ms.fullId.parseModelFile("actual_params.json", ActualModelParameters.class);
            assert amp != null;
            assert amp.resolved != null;
            assert ms.userMeta != null;
            long trainingTimeSeconds = (mti.trainingTime + mti.preprocessingTime) / 1000L;
            // Below two minutes, show whole seconds; otherwise, whole (truncated) minutes.
            String trainedIn = trainingTimeSeconds < 120L ? trainingTimeSeconds + " seconds" : trainingTimeSeconds / 60L + " minutes";
            SplitDesc.SplitRef sr = ms.fullId.parseSessionFile("split_ref.json", SplitDesc.SplitRef.class);
            File splitFile = new File(ms.fullId.getTaskLoc().getSplitsFolder(), sr.splitInstanceId + ".json");
            SplitDesc splitDesc = (SplitDesc) JSON.parseFile(splitFile, SplitDesc.class);
            // A split with fullRows == 0 falls back to the row count recorded in the train info.
            long recordCount = splitDesc.fullRows == 0L ? mti.fullRows : splitDesc.fullRows;
            // Treat a missing description as empty. Plain concatenation would otherwise prepend a literal "null".
            String previousDescription = ms.userMeta.description == null ? "" : ms.userMeta.description;
            ms.userMeta.description = previousDescription + "\nTrained in " + trainedIn + " on " + recordCount + " records\n";
            ms.userMeta.labels = EvaluationLabelsHelper.setTrainTime(ms.userMeta.labels, mti.endTime);
            ms.userMeta.clusterMetas = createClusterColors(ms.run_folder);
            JSON.prettyToFile(ms.userMeta, new File(ms.run_folder, "user_meta.json"));
        }
    }

    /**
     * Builds one cluster meta per label in {@code <runFolder>/heatmap.json}, keyed by label.
     *
     * <p>The i-th label gets color {@code i} of the indexed "categorical" palette. If a label appears more
     * than once, the later entry replaces the earlier one in the map.
     *
     * @param runFolder the run folder containing {@code heatmap.json}
     * @return cluster metas keyed by cluster label
     * @throws IOException if {@code heatmap.json} cannot be read
     */
    private static Map<String, ModelUserMeta.ClusterMeta> createClusterColors(String runFolder) throws IOException {
        Map<String, ModelUserMeta.ClusterMeta> clusterMetas = new HashMap<>();
        HeatMap heatMap = (HeatMap) JSON.parseFile(new File(runFolder, "heatmap.json"), HeatMap.class);
        DiscretePalette palette = PaletteFactory.buildIndexedPalette("categorical");
        int index = 0;
        for (String name : heatMap.cluster_labels) {
            Color color = palette.apply(index++);
            ModelUserMeta.ClusterMeta clusterMeta = new ModelUserMeta.ClusterMeta();
            clusterMeta.color = color.toHtml();
            clusterMeta.name = name;
            clusterMetas.put(name, clusterMeta);
        }
        return clusterMetas;
    }
}

