/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction;

import com.dataiku.common.server.APIError;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.prediction.PartitionedExtractService;
import com.dataiku.dip.analysis.ml.prediction.StratifiedMetricsAggregator;
import com.dataiku.dip.analysis.ml.shared.ModelStateHelper;
import com.dataiku.dip.analysis.model.ModelDetailsBase;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.prediction.ActualModelParameters;
import com.dataiku.dip.analysis.model.prediction.ClassicalPredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.PartitionedModelExtract;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.TimeseriesForecastingModelDetails;
import com.dataiku.dip.partitioning.StratifiedModelUtils;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.gson.JsonObject;
import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Assembles the global model of a partitioned training from the per-partition outputs.
 *
 * <p>The entry point is {@link #gatherBaseModelInfo(FullModelId)}. It fills the global model folder,
 * calls {@code PartitionedExtractService.dumpAndClear}, and rewrites the global
 * {@code train_info.json} with aggregated times and a global state.
 */
@Service
public class PartitionedModelsService {
    /** Train info file name, used both in partition folders and in the global model folder. */
    private static final String TRAIN_INFO_FILE = "train_info.json";

    @Autowired
    private PartitionedExtractService partitionedExtractService;
    private final DKULogger logger = DKULogger.getLogger("dku.ml.partitions");

    /**
     * Builds the global model folder and finalizes its train info.
     *
     * <p>This is a no-op when the global {@code train_info.json} exists and its state is neither
     * RUNNING nor PENDING. Otherwise the method does four things, in order:
     * <ol>
     *   <li>fills the base model folder;</li>
     *   <li>calls {@code dumpAndClear} on the extract service;</li>
     *   <li>updates and rewrites {@code train_info.json};</li>
     *   <li>adds the model date label.</li>
     * </ol>
     * Any exception raised while filling the folder is logged and recorded as a FAILED state
     * instead of being rethrown.
     *
     * @param globalModelId identifier of the global model
     * @throws IOException propagated from the train-info checks/update or the extract service
     */
    public void gatherBaseModelInfo(FullModelId globalModelId) throws IOException {
        if (hasModelCompleted(globalModelId)) {
            return;
        }
        RuntimeException gatheringFailure = tryFillBaseModelFolder(globalModelId);
        partitionedExtractService.dumpAndClear(globalModelId);
        ModelTrainInfo globalTrainInfo = setTrainInfo(globalModelId, gatheringFailure);
        ModelStateHelper.addModelDateLabel(globalModelId, globalTrainInfo);
    }

    /**
     * Runs {@link #fillBaseModelFolder(FullModelId)} and converts any exception into a returned value.
     *
     * @return {@code null} on success, otherwise a RuntimeException wrapping the original failure
     */
    private RuntimeException tryFillBaseModelFolder(FullModelId globalModelId) {
        try {
            fillBaseModelFolder(globalModelId);
            return null;
        } catch (Exception e) {
            String message = "Failed to fill base model folder of model '" + globalModelId + "'";
            logger.warn((Object) message, (Throwable) e);
            return new RuntimeException("Failed to gather global model information", e);
        }
    }

    /**
     * Tells whether the global train info already records a state other than RUNNING or PENDING.
     *
     * @return {@code false} if {@code train_info.json} is missing or its state is RUNNING/PENDING
     */
    private boolean hasModelCompleted(FullModelId globalModelId) throws IOException {
        boolean trainInfoPresent = globalModelId.getModelFile(TRAIN_INFO_FILE).exists();
        if (!trainInfoPresent) {
            return false;
        }
        ModelTrainInfo current = globalModelId.parseModelFile(TRAIN_INFO_FILE, ModelTrainInfo.class);
        boolean inProgress = current.state == ModelTrainInfo.ModelTrainState.RUNNING
                || current.state == ModelTrainInfo.ModelTrainState.PENDING;
        return !inProgress;
    }

    /**
     * Populates the global model folder from the per-partition model details.
     *
     * <p>The details come from {@code StratifiedMetricsAggregator.retrievePerPartitionMetrics}
     * (presumably the partitions that finished training; verify against that helper). The method
     * does nothing when that list is empty. Otherwise it:
     * <ol>
     *   <li>copies the {@code .json} files from the first partition's preprocessing folder into the
     *       global one, using the {@code NoReplacing} flag;</li>
     *   <li>writes {@code actual_params.json} from the algorithm and skipExpensiveReports values in
     *       the global {@code rmodeling_params.json};</li>
     *   <li>computes and saves the overall metrics for the prediction type of that first
     *       partition.</li>
     * </ol>
     *
     * @throws IllegalArgumentException if the first partition's details are neither classical
     *     prediction nor time series forecasting details. This is checked only after the copy and
     *     after {@code actual_params.json} has been written.
     */
    private void fillBaseModelFolder(FullModelId globalModelId) throws IOException {
        List<ModelDetailsBase> finishedPartitions = StratifiedMetricsAggregator.retrievePerPartitionMetrics(globalModelId);
        if (finishedPartitions.isEmpty()) {
            return;
        }
        ModelDetailsBase referenceDetails = finishedPartitions.get(0);
        FullModelId referencePartition = FullModelId.parse(referenceDetails.fullModelId);

        FileFilter jsonFilesOnly = candidate -> candidate.getName().endsWith(".json");
        DKUFileUtils.copyDirectory(
                (File) referencePartition.getPreprocessingFolder(),
                (File) globalModelId.getPreprocessingFolder(),
                EnumSet.of(DKUFileUtils.CopyDirectoryFlags.NoReplacing),
                new FileFilter[]{jsonFilesOnly});

        writeActualParams(globalModelId);

        ResolvedPredictionCoreParams coreParams = coreParamsOf(referenceDetails);
        StratifiedMetricsAggregator.computeAndSaveOverallMetrics(coreParams.prediction_type, finishedPartitions, globalModelId);
    }

    /**
     * Writes the global {@code actual_params.json}.
     *
     * <p>Only {@code algorithm} and {@code skipExpensiveReports} are carried over from
     * {@code rmodeling_params.json}; {@code other} is an empty JSON object.
     */
    private void writeActualParams(FullModelId globalModelId) throws IOException {
        PreTrainPredictionModelingParams requested =
                globalModelId.parseModelFile("rmodeling_params.json", PreTrainPredictionModelingParams.class);
        PostTrainPredictionModelingParams resolved = new PostTrainPredictionModelingParams();
        resolved.algorithm = requested.algorithm;
        resolved.skipExpensiveReports = requested.skipExpensiveReports;

        ActualModelParameters actualParams = new ActualModelParameters();
        actualParams.resolved = resolved;
        actualParams.other = new JsonObject();
        JSON.prettyToFile(actualParams, globalModelId.getModelFile("actual_params.json"));
    }

    /**
     * Extracts the resolved core params from a supported model details instance.
     *
     * @throws IllegalArgumentException for any other details class
     */
    private static ResolvedPredictionCoreParams coreParamsOf(ModelDetailsBase details) {
        if (details instanceof ClassicalPredictionModelDetails) {
            return ((ClassicalPredictionModelDetails) details).coreParams;
        }
        if (details instanceof TimeseriesForecastingModelDetails) {
            return ((TimeseriesForecastingModelDetails) details).coreParams;
        }
        throw new IllegalArgumentException("Unsupported model details class: " + details.getClass().getSimpleName());
    }

    /**
     * Aggregates the partition train times into the global train info, sets its state, and saves it.
     *
     * <p>The start time is the earliest strictly positive partition start time, or
     * {@code Long.MAX_VALUE} if none. The end time is the latest partition end time, or 0 if none.
     * Partitions without a {@code train_info.json} are skipped.
     * NOTE(review): the {@code Long.MAX_VALUE} sentinel is forwarded as-is to
     * {@code StratifiedMetricsAggregator.setTrainInfo}; confirm that it handles this case.
     *
     * <p>The state is FAILED (with the failure recorded) when {@code error} is non-null. Otherwise
     * it is derived from the partition extract states; see {@link #stateFromPartitions(FullModelId)}.
     *
     * @return the train info that was written to the global {@code train_info.json}
     */
    private ModelTrainInfo setTrainInfo(FullModelId globalModelId, Throwable error) throws IOException {
        long earliestStart = Long.MAX_VALUE;
        long latestEnd = 0L;
        for (FullModelId partitionFmi : StratifiedModelUtils.fetchPartitionFmis(globalModelId)) {
            if (partitionFmi.getModelFile(TRAIN_INFO_FILE).exists()) {
                ModelTrainInfo partitionInfo = partitionFmi.parseModelFile(TRAIN_INFO_FILE, ModelTrainInfo.class);
                if (partitionInfo.startTime > 0L) {
                    earliestStart = Math.min(earliestStart, partitionInfo.startTime);
                }
                latestEnd = Math.max(latestEnd, partitionInfo.endTime);
            }
        }

        ModelTrainInfo globalInfo = globalModelId.parseModelFile(TRAIN_INFO_FILE, ModelTrainInfo.class);
        StratifiedMetricsAggregator.setTrainInfo(globalInfo, earliestStart, latestEnd);
        if (error == null) {
            globalInfo.state = stateFromPartitions(globalModelId);
        } else {
            globalInfo.failure = new APIError(error, true);
            globalInfo.state = ModelTrainInfo.ModelTrainState.FAILED;
        }
        JSON.prettyToFile(globalInfo, globalModelId.getModelFile(TRAIN_INFO_FILE));
        return globalInfo;
    }

    /**
     * Derives the global state from the partition extract counts.
     *
     * @return DONE if at least one partition is DONE; otherwise ABORTED if at least one is ABORTED;
     *     otherwise FAILED
     */
    private ModelTrainInfo.ModelTrainState stateFromPartitions(FullModelId globalModelId) throws IOException {
        Map<PartitionedModelExtract.PartitionState, Integer> countsByState = partitionedExtractService.read(globalModelId).states;
        if (isPositive(countsByState.get(PartitionedModelExtract.PartitionState.DONE))) {
            return ModelTrainInfo.ModelTrainState.DONE;
        }
        if (isPositive(countsByState.get(PartitionedModelExtract.PartitionState.ABORTED))) {
            return ModelTrainInfo.ModelTrainState.ABORTED;
        }
        return ModelTrainInfo.ModelTrainState.FAILED;
    }

    /** Returns true when {@code count} is non-null and strictly greater than zero. */
    private static boolean isPositive(Integer count) {
        return count != null && count > 0;
    }
}

