/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.externalml.mlflow;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.interactivemodel.InteractiveModelKernel;
import com.dataiku.dip.analysis.ml.shared.EvaluationLabelsHelper;
import com.dataiku.dip.analysis.model.core.ModelUserMeta;
import com.dataiku.dip.analysis.model.prediction.BinaryClassificationModelPerf;
import com.dataiku.dip.code.CodeEnvModel;
import com.dataiku.dip.code.CodeEnvResolutionService;
import com.dataiku.dip.connections.ConnectionsDAO;
import com.dataiku.dip.connections.DSSConnection;
import com.dataiku.dip.containers.exec.ContainerExecSelection;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.datasets.SamplingParam;
import com.dataiku.dip.datasets.StreamableDatasetSelection;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.exceptions.IllegalConfigurationException;
import com.dataiku.dip.exceptions.UnauthorizedException;
import com.dataiku.dip.externalml.mlflow.MLFlowModelVersionInfo;
import com.dataiku.dip.mec.EvaluationSamplingService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.PermissionsService;
import com.dataiku.dip.security.Privileges;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.SmartLogTail;
import com.google.common.base.Joiner;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class MLflowModelOperationsService {
    @Autowired
    DatasetsDAO datasetsDAO;
    @Autowired
    EvaluationSamplingService evaluationSamplingService;
    @Autowired
    PermissionsService permissionsService;
    @Autowired
    TransactionService transactionService;
    @Autowired
    ConnectionsDAO connectionsDAO;
    @Autowired
    private CodeEnvResolutionService codeEnvResolutionService;
    private final DKULogger logger = DKULogger.getLogger((String)"dku.services.mlflowModelOperationsService");

    /** Substring of a kernel exception message identifying a Python kernel start failure. */
    private static final String KERNEL_START_FAILURE_MARKER = "Interactive Model Python kernel failed to start";
    /** Substring of a cause message for which the original exception is rethrown unchanged. */
    private static final String NO_IMAGE_TAG_MARKER = "No recorded image tag for env";

    /**
     * Stores the signature/format settings of an imported MLflow model version and asks the
     * Python kernel to apply them via the {@code SET_SIGNATURE_AND_FORMATS} command.
     * <p>
     * When {@code mvi.gatherFeaturesFromDataset} is set, {@code mvi.features} is replaced by the
     * columns of that dataset's schema. In all cases the target column is then removed from
     * {@code mvi.features}. {@code mvi} is mutated in place, and the result is written to the
     * model's {@code mlflow_imported_model.json} before the kernel command runs.
     *
     * @param authCtx caller's auth context, used for connection and code-env checks
     * @param fmi identifier of the model version being configured
     * @param mvi new settings; mutated in place (see above)
     * @param containerExecSelection execution settings passed to the kernel
     * @throws Exception if an access check fails, the features dataset cannot be resolved,
     *         the kernel fails to start, or the kernel command fails
     */
    public void setSignatureAndFormats(AuthCtx authCtx, FullModelId fmi, MLFlowModelVersionInfo mvi, ContainerExecSelection containerExecSelection) throws Exception {
        // Access checks use the metadata currently stored for the model, not the incoming mvi.
        MLFlowModelVersionInfo mim = fmi.getMLflowImportedModelMetadata();
        this.checkAccess(authCtx, mim);
        try (Transaction t = this.transactionService.retrieveOrBeginRead()) {
            if (mvi.gatherFeaturesFromDataset != null) {
                mvi.features.clear();
                SerializedDataset featuresDataset = (SerializedDataset) this.datasetsDAO.getMandatory(AnyLoc.resolveSmart(fmi.getProjectKey(), mvi.gatherFeaturesFromDataset));
                for (SchemaColumn col : featuresDataset.getSchema().getColumns()) {
                    mvi.features.add(col);
                }
            }
            // The target column is never kept as an input feature.
            mvi.features = mvi.features.stream()
                    .filter(col -> !col.getName().equals(mvi.targetColumnName))
                    .collect(Collectors.toList());
            File modelFile = fmi.getModelFile("mlflow_imported_model.json");
            JSON.prettyToFile((Object)mvi, (File)modelFile);
            this.logger.info((Object)("Setting signature for model " + String.valueOf(fmi)));
        }
        InteractiveModelKernel kernel = this.startKernel(authCtx, fmi, mim, containerExecSelection);
        HashMap<String, Object> params = new HashMap<String, Object>();
        params.put("featuresDataset", mvi.gatherFeaturesFromDataset);
        params.put("signatureAndFormatsGuessingDataset", mvi.signatureAndFormatsGuessingDataset);
        params.put("inputFormat", mvi.inputFormat);
        params.put("outputFormat", mvi.outputFormat);
        params.put("target", mvi.targetColumnName);
        params.put("features", mvi.features);
        kernel.runNonInteractiveMLflowCommand("SET_SIGNATURE_AND_FORMATS", params);
    }

    /**
     * Evaluates the model on a dataset.
     * <p>
     * Delegates to the full overload with {@code useOptimalThreshold = true} and
     * {@code skipExpensiveReports = false}. {@code sds} is passed as the sampling parameter.
     *
     * @throws Exception see the full overload
     */
    public void evaluate(AuthCtx authCtx, FullModelId fmi, String datasetRef, ContainerExecSelection containerExecSelection, StreamableDatasetSelection sds) throws Exception {
        this.evaluate(authCtx, fmi, datasetRef, containerExecSelection, sds, true, false);
    }

    /**
     * Evaluates the imported MLflow model on {@code datasetRef} via the kernel {@code EVALUATE}
     * command, then updates the model's user meta and imported-model metadata.
     * <p>
     * The steps after evaluation are:
     * <ul>
     *   <li>if {@code useOptimalThreshold} is set and the evaluation left a {@code perf.json},
     *       the active classifier threshold is set to its optimal threshold;</li>
     *   <li>evaluation-time labels are recomputed;</li>
     *   <li>the sampling parameter and dataset name are recorded in the model metadata.</li>
     * </ul>
     *
     * @param authCtx caller's auth context, used for connection and code-env checks
     * @param fmi identifier of the model version to evaluate
     * @param datasetRef smart name of the evaluation dataset, resolved against the model's project
     * @param containerExecSelection execution settings passed to the kernel
     * @param samplingParam evaluation sampling; {@code null} means the first 10K rows
     * @param useOptimalThreshold whether to adopt the optimal threshold from {@code perf.json}
     * @param skipExpensiveReports forwarded to the kernel as {@code skip_expensive_reports}
     * @throws Exception if an access check fails, the dataset cannot be resolved, or the kernel
     *         fails to start or run
     */
    public void evaluate(AuthCtx authCtx, FullModelId fmi, String datasetRef, ContainerExecSelection containerExecSelection, SamplingParam samplingParam, boolean useOptimalThreshold, boolean skipExpensiveReports) throws Exception {
        this.logger.info((Object)("Evaluating model " + fmi.toString() + " on dataset " + datasetRef));
        StreamableDatasetSelection sds = samplingParam == null ? StreamableDatasetSelection.head10K() : StreamableDatasetSelection.fromSamplingParam(samplingParam);
        MLFlowModelVersionInfo mim = fmi.getMLflowImportedModelMetadata();
        this.checkAccess(authCtx, mim);
        SerializedDataset serializedDataset;
        try (Transaction t = this.transactionService.beginRead()) {
            serializedDataset = (SerializedDataset) this.datasetsDAO.getMandatory(AnyLoc.resolveSmart(fmi.getProjectKey(), datasetRef));
        }
        InteractiveModelKernel kernel = this.startKernel(authCtx, fmi, mim, containerExecSelection);
        HashMap<String, Object> params = new HashMap<String, Object>();
        params.put("dataset_ref", datasetRef);
        params.put("schema", serializedDataset.getSchema());
        params.put("selection", sds);
        params.put("skip_expensive_reports", skipExpensiveReports);
        kernel.runNonInteractiveMLflowCommand("EVALUATE", params);

        ModelUserMeta mum = fmi.getUserMeta();
        // perf.json is only read when present; it is parsed as binary-classification perf.
        if (useOptimalThreshold && fmi.getModelFile("perf.json").exists()) {
            BinaryClassificationModelPerf perf = fmi.parseModelFile("perf.json", BinaryClassificationModelPerf.class);
            mum.activeClassifierThreshold = perf.optimalThreshold;
        }
        mum.labels = EvaluationLabelsHelper.getEvaluationTimeLabels_T(fmi.getSavedModelProjectKey(), mum, serializedDataset, mum.labels, null, null);
        fmi.saveUserMeta(mum);

        mim.evaluationSamplingParam = samplingParam;
        mim.evaluationDatasetSmartName = datasetRef;
        fmi.writeMLflowImportedModelMetadata(mim);
    }

    /**
     * Runs the kernel {@code READ_META} command for the imported MLflow model.
     * <p>
     * Failures from either starting the kernel or running the command go through
     * {@link #processKernelException}, which always throws.
     *
     * @throws Exception if an access check fails, or the kernel fails to start or run
     */
    public void readMeta(AuthCtx authCtx, FullModelId fmi, ContainerExecSelection containerExecSelection) throws Exception {
        MLFlowModelVersionInfo mim = fmi.getMLflowImportedModelMetadata();
        this.checkAccess(authCtx, mim);
        InteractiveModelKernel kernel = this.startKernel(authCtx, fmi, mim, containerExecSelection);
        try {
            kernel.runNonInteractiveMLflowCommand("READ_META", new HashMap<String, Object>());
        }
        catch (Exception e) {
            this.logger.error((Object)"Got import error when reading meta for MLflow model", (Throwable)e);
            this.processKernelException(mim, kernel, e);
        }
    }

    /**
     * Creates an interactive model kernel for {@code fmi} and starts it if needed.
     * <p>
     * Start failures are logged and passed to {@link #processKernelException}, which always throws.
     * A returned kernel has therefore started successfully.
     */
    private InteractiveModelKernel startKernel(AuthCtx authCtx, FullModelId fmi, MLFlowModelVersionInfo mim, ContainerExecSelection containerExecSelection) throws Exception {
        InteractiveModelKernel kernel = new InteractiveModelKernel(authCtx, fmi, true, containerExecSelection);
        try {
            kernel.startIfNeeded();
        }
        catch (Exception e) {
            this.logger.error((Object)"Got error when starting kernel meta for MLflow model", (Throwable)e);
            this.processKernelException(mim, kernel, e);
        }
        return kernel;
    }

    /**
     * Translates a kernel failure into a more actionable exception. This method always throws.
     * <p>
     * The outcome depends on the failure:
     * <ul>
     *   <li>Rethrows {@code e} unchanged when it is not a kernel start failure, or when its cause
     *       reports a missing image tag.</li>
     *   <li>For proxy models, throws an {@link IllegalConfigurationException} pointing at the
     *       code environment.</li>
     *   <li>Otherwise, throws a {@link RuntimeException} that names the code environment,
     *       includes the kernel's highest-level log lines when its max level is 2 or 3, and
     *       keeps {@code e} as its cause.</li>
     * </ul>
     * Null messages or causes on {@code e} are tolerated, so they cannot mask the original error
     * with a NullPointerException.
     */
    private void processKernelException(MLFlowModelVersionInfo mim, InteractiveModelKernel kernel, Exception e) throws Exception {
        Throwable cause = e.getCause();
        String causeMessage = cause == null ? null : cause.getMessage();
        boolean isKernelStartFailure = StringUtils.contains(e.getMessage(), KERNEL_START_FAILURE_MARKER);
        if (!isKernelStartFailure || StringUtils.contains(causeMessage, NO_IMAGE_TAG_MARKER)) {
            throw e;
        }
        if (mim.isProxyModel()) {
            throw new IllegalConfigurationException(e.getMessage() + ". Check that code environment " + mim.pythonCodeEnvName + " exists or an admin should create it in Administration > Code Envs > Internal envs setup.");
        }
        SmartLogTail kernelLogTail = kernel.getKernelLogTail();
        String pythonLog = "";
        // NOTE(review): levels 2 and 3 presumably correspond to warning/error severities — confirm against SmartLogTail.
        if (kernelLogTail != null && (3 == kernelLogTail.maxLevel || 2 == kernelLogTail.maxLevel)) {
            pythonLog = "Python log: " + Joiner.on("\n").join((Iterable<?>) kernelLogTail.getMaxLevelLines());
        }
        throw new RuntimeException(e.getMessage() + ". Check that code environment " + mim.pythonCodeEnvName + " exists and that it contains core packages and visual ML packages.\n" + pythonLog, e);
    }

    /**
     * Runs the access checks shared by all public operations.
     * <p>
     * The proxy-connection check runs first, then the code-env check on
     * {@code mim.pythonCodeEnvName}.
     */
    private void checkAccess(AuthCtx authCtx, MLFlowModelVersionInfo mim) throws DKUSecurityException, IOException {
        this.checkPermissionsOnConnection(mim, authCtx);
        this.checkCodeEnv(authCtx, mim.pythonCodeEnvName);
    }

    /**
     * Checks access to a Python code environment.
     * <p>
     * First verifies that {@code authCtx} holds the USE privilege on it, inside a read
     * transaction. Then verifies that the environment exists.
     */
    private void checkCodeEnv(AuthCtx authCtx, String codeEnvName) throws DKUSecurityException, IOException {
        try (Transaction t = this.transactionService.retrieveOrBeginRead()) {
            this.permissionsService.checkCodeEnvPrivileges(authCtx, CodeEnvModel.EnvLang.PYTHON, codeEnvName, Privileges.CodeEnvLevelPrivilegeType.USE);
        }
        this.codeEnvResolutionService.checkEnvExists(CodeEnvModel.EnvLang.PYTHON, codeEnvName);
    }

    /**
     * For proxy models with a non-blank connection, requires that the connection's details are
     * readable by {@code authCtx}. Other models pass without any check.
     *
     * @throws UnauthorizedException if the connection details are not readable
     */
    private void checkPermissionsOnConnection(MLFlowModelVersionInfo mim, AuthCtx authCtx) throws DKUSecurityException, IOException {
        if (!mim.isProxyModel()) {
            return;
        }
        String connectionName = mim.getProxyModelConnection();
        if (StringUtils.isBlank(connectionName)) {
            return;
        }
        DSSConnection conn = this.connectionsDAO.getMandatoryConnection(authCtx, connectionName);
        if (!conn.detailsReadableBy(authCtx)) {
            throw new UnauthorizedException("Permission to view details of the configured connection " + connectionName + " is required", "connection-info-denied");
        }
    }
}

