/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.ModelLikeId;
import com.dataiku.dip.analysis.ml.prediction.PredictionPostComputationHandler;
import com.dataiku.dip.analysis.ml.prediction.PythonPostTrainComputationHandler;
import com.dataiku.dip.analysis.ml.prediction.flow.AbstractPredictionTrainingRecipePayloadParams;
import com.dataiku.dip.analysis.ml.prediction.flow.PredictionRecipesService;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.ml.shared.EvaluationLabelsHelper;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.ModelUserMeta;
import com.dataiku.dip.analysis.model.prediction.BinaryClassificationModelPerf;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.DKULogger;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

public class PythonPostTrainRetrainingComputationHandler
extends PythonPostTrainComputationHandler {
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.analysis.ml.prediction.posttraining.retraining");

    public PythonPostTrainRetrainingComputationHandler(AuthCtx authCtx, String jobId, ModelLikeId mle, PredictionPostComputationHandler.PostComputationCommand computationCommand, JsonObject computationParameters) {
        super(authCtx, jobId, mle, computationCommand, computationParameters);
        if (mle.getModelLikeType() != ModelLikeId.ModelLikeType.DOCTOR_MODEL || ((FullModelId)mle).type != FullModelId.Type.SAVED) {
            throw new IllegalArgumentException("Can only run retraining a saved model version of Doctor model, current id is " + String.valueOf(mle));
        }
    }

    @Override
    protected Map<String, Object> prepareParams() throws IOException {
        HashMap<String, Object> params = new HashMap<String, Object>();
        params.put("exec_folder", this.mle.getUnderlyingModel().getModelFolder().getAbsolutePath());
        params.put("operation_mode", (Object)this.getOperationMode());
        return params;
    }

    @Override
    public void postCompute() {
        try {
            ((PredictionRecipesService)SpringUtils.getBean(PredictionRecipesService.class)).computeAndSaveMetrics(this.authCtx, (FullModelId)this.mle);
            this.updateAndSaveUserMeta();
        }
        catch (Exception e) {
            logger.warn((Object)"Failed to post-compute", (Throwable)e);
        }
    }

    private void updateAndSaveUserMeta() throws IOException {
        BinaryClassificationModelPerf perf;
        FullModelId fmi = (FullModelId)this.mle;
        ModelUserMeta mum = fmi.getUserMeta();
        ModelTrainInfo mti = fmi.getTrainModelInfo().orElseThrow(() -> new IllegalStateException(String.format("Train model info does not exist for saved model %s", fmi)));
        mum.labels = EvaluationLabelsHelper.setModelNameToLabels(mum.labels, mum.name);
        mum.labels = EvaluationLabelsHelper.setTrainTime(mum.labels, mti.endTime);
        if (fmi.getPredictionType() == PredictionMLTask.PredictionType.BINARY_CLASSIFICATION && (perf = (BinaryClassificationModelPerf)fmi.getClassicalPredictionPerf().orElse(null)) != null) {
            mum.activeClassifierThreshold = perf.usedThreshold;
        }
        fmi.saveUserMeta(mum);
    }

    private AbstractPredictionTrainingRecipePayloadParams.OperationMode getOperationMode() throws IOException {
        SplitDesc splitDesc = ((FullModelId)this.mle).getSplitDesc();
        return AbstractPredictionTrainingRecipePayloadParams.OperationMode.fromSplitDesc(splitDesc);
    }
}

