/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.unifiedmonitoring.wizard;

import com.dataiku.dip.SmartObjectRef;
import com.dataiku.dip.analysis.coreservices.flow.SavedModelsCRUDService;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.prediction.flow.PredictionRecipesService;
import com.dataiku.dip.connections.AbstractCloudStorageConnection;
import com.dataiku.dip.connections.DSSConnection;
import com.dataiku.dip.connections.EC2Connection;
import com.dataiku.dip.connections.FSProviderizableConnection;
import com.dataiku.dip.connections.FilesBasedConnectionsDAO;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.coremodel.Zone;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dataflow.FlowGraphService;
import com.dataiku.dip.datasets.fs.AbstractFSDatasetHandler;
import com.dataiku.dip.datasets.fs.BuiltinFSDatasets;
import com.dataiku.dip.datasets.fs.FSLikeDatasetTestHandler;
import com.dataiku.dip.externalinfras.sagemaker.SageMakerUtils;
import com.dataiku.dip.futures.FutureResponse;
import com.dataiku.dip.futures.FutureService;
import com.dataiku.dip.mec.ModelEvaluationStore;
import com.dataiku.dip.mec.ModelEvaluationStoresCRUDService;
import com.dataiku.dip.partitioning.PartitioningScheme;
import com.dataiku.dip.recipes.ManagedDatasetsCreationService;
import com.dataiku.dip.savedmodels.proxymodelversions.ProxyModelVersionConfiguration;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.datasets.DatasetSaveService;
import com.dataiku.dip.server.datasets.DatasetsTestController;
import com.dataiku.dip.server.services.FlowZonesService;
import com.dataiku.dip.server.services.ITaggingService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.RWTransaction;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.unifiedmonitoring.wizard.MonitoringWizardCreatedSummary;
import com.dataiku.dip.unifiedmonitoring.wizard.MonitoringWizardSettings;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.PathUtils;
import com.dataiku.dss.shadelib.org.apache.commons.lang3.StringUtils;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemaker.model.CaptureMode;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemaker.model.CaptureOption;
import com.google.gson.reflect.TypeToken;
import java.util.Arrays;
import java.util.Objects;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Creates the flow objects used to monitor a saved model from its prediction logs.
 *
 * <p>The objects are:
 * <ul>
 *   <li>a logs input dataset;</li>
 *   <li>optionally, a model evaluation store;</li>
 *   <li>optionally, an output dataset;</li>
 *   <li>an evaluation recipe wiring them together, created only when a model evaluation store is created.</li>
 * </ul>
 */
@Service
public class MonitoringWizardService {
    @Autowired
    protected FilesBasedConnectionsDAO filesBasedConnectionsDAO;
    @Autowired
    protected FutureService futureService;
    @Autowired
    protected TransactionService transactionService;
    @Autowired
    protected DatasetSaveService datasetSaveService;
    @Autowired
    protected DatasetsDAO datasetsDAO;
    @Autowired
    protected ModelEvaluationStoresCRUDService modelEvaluationStoresCRUDService;
    @Autowired
    protected FlowZonesService flowZonesService;
    @Autowired
    protected FlowGraphService graphService;
    @Autowired
    protected SavedModelsCRUDService savedModelsCRUDService;
    @Autowired
    protected PredictionRecipesService predictionRecipesService;
    @Autowired
    protected ManagedDatasetsCreationService managedDatasetsCreationService;
    private static final DKULogger logger = DKULogger.getLogger("unifiedmonitoring.wizard.service");

    /**
     * Sets up monitoring for an external (proxy) saved model version from its captured prediction logs.
     *
     * <p>The logs location is derived from the endpoint's prediction logs URI and validated
     * against the given cloud storage connection. Flow objects are then created as described
     * by {@code monitoringSettings}. If the saved model lives in a non-default flow zone, the
     * created objects are moved into that zone.
     *
     * @param user               the user on whose behalf objects are created
     * @param projectKey         the project in which the monitoring objects are created
     * @param fmi                the full id of the external (proxy) saved model version
     * @param connection         the name of the cloud storage connection holding the logs
     * @param monitoringSettings which optional objects to create, plus the deployment id
     * @return a summary of the created flow objects
     * @throws IllegalArgumentException if the model version is not an external model with data
     *                                  capture enabled, if the saved model cannot be found, or if
     *                                  the connection does not match the logs location
     */
    public MonitoringWizardCreatedSummary setupExternalModelMonitoring(AuthCtx user, String projectKey, FullModelId fmi, String connection, MonitoringWizardSettings monitoringSettings) throws Exception {
        AbstractCloudStorageConnection cloudStorageConnection = this.filesBasedConnectionsDAO.getMandatoryConnectionAs(user, connection, AbstractCloudStorageConnection.class);
        if (fmi.getMLflowImportedModelMetadata() == null || fmi.getMLflowImportedModelMetadata().proxyModelVersionConfiguration == null || fmi.getMLflowImportedModelMetadata().proxyModelEndpointInfo == null) {
            throw new IllegalArgumentException(String.format("The saved model version %s is not an external (proxy) saved model version", fmi));
        }
        ProxyModelVersionConfiguration proxyModelVersionConfiguration = fmi.getMLflowImportedModelMetadata().proxyModelVersionConfiguration;
        ProxyModelVersionConfiguration.ConsolidatedEndpointInfo endpointInfo = fmi.getMLflowImportedModelMetadata().proxyModelEndpointInfo;
        String predictionLogsUri = endpointInfo.predictionLogsUri;
        if (predictionLogsUri == null) {
            throw new IllegalArgumentException(String.format("Data capture is not enabled for your external (proxy) saved model version %s", fmi));
        }
        AbstractFSDatasetHandler.AbstractFSConfig logsConnectionConfig = this.getLogsConnectionConfig(proxyModelVersionConfiguration, endpointInfo, cloudStorageConnection, predictionLogsUri, fmi);
        logger.infoV("Setup monitoring for external model %s with the path %s of connection %s", new Object[]{fmi, logsConnectionConfig.path, connection});
        SmartObjectRef smRef = SmartObjectRef.fromResolved(ITaggingService.TaggableType.SAVED_MODEL, fmi.getProjectKey(), fmi.getSavedModelID(), projectKey);
        MonitoringWizardSettings.FlowObjectsCreationRequestParams flowObjectsCreationRequestParams = new MonitoringWizardSettings.FlowObjectsCreationRequestParams();
        // The read transaction only scopes the lookups below; it is never committed.
        try (Transaction t = this.transactionService.beginRead()) {
            SavedModel sm = this.savedModelsCRUDService.getOrNull(projectKey, fmi.getSavedModelID());
            if (sm == null) {
                throw new IllegalArgumentException(String.format("Cannot find the saved model %s in project %s", fmi.getSavedModelID(), projectKey));
            }
            String baseName = sm.getDisplayName();
            flowObjectsCreationRequestParams.inputDatasetName = this.datasetSaveService.transmogrifyName(projectKey, baseName + "_logs");
            if (monitoringSettings.createMes) {
                flowObjectsCreationRequestParams.mesName = this.modelEvaluationStoresCRUDService.transmogrifyName(projectKey, baseName + "_mes");
                // An output (scored) dataset only makes sense when an evaluation recipe is created.
                if (monitoringSettings.createOutputDataset) {
                    flowObjectsCreationRequestParams.outputDatasetName = this.datasetSaveService.transmogrifyName(projectKey, baseName + "_logs_decoded");
                }
            }
        }
        MonitoringWizardCreatedSummary monitoringWizardCreatedSummary = this.createMonitoringElements(user, projectKey, cloudStorageConnection, logsConnectionConfig, flowObjectsCreationRequestParams, smRef, null, null, monitoringSettings.deploymentId);
        this.moveToSavedModelZone(user, projectKey, smRef, monitoringWizardCreatedSummary);
        return monitoringWizardCreatedSummary;
    }

    /**
     * Moves the created monitoring objects into the saved model's flow zone.
     *
     * <p>Nothing is moved or committed when the saved model has no zone or is in the default zone.
     */
    private void moveToSavedModelZone(AuthCtx user, String projectKey, SmartObjectRef smRef, MonitoringWizardCreatedSummary summary) throws Exception {
        try (RWTransaction t = this.transactionService.beginWriteAsLoggedInUser(user)) {
            String zone = this.flowZonesService.retrieveZone(projectKey, smRef);
            if (zone == null || zone.equals(Zone.DEFAULT_ZONE.getId())) {
                return;
            }
            for (SmartObjectRef movingItem : summary.getNonNullObjectRefs()) {
                this.flowZonesService.detachObjectFromZone(projectKey, movingItem);
                this.flowZonesService.attachObjectToZone(zone, projectKey, movingItem, false);
            }
            this.graphService.invalidateCache(projectKey);
            t.commit("Monitoring elements placed in the saved model flow zone.");
        }
    }

    /**
     * Creates the monitoring flow objects described by {@code flowObjectsCreationRequestParams}.
     *
     * <p>What gets created:
     * <ul>
     *   <li>The input logs dataset is created only if no dataset with that name exists yet.</li>
     *   <li>A model evaluation store is created when {@code mesName} is non-null.</li>
     *   <li>A managed output dataset is created when {@code outputDatasetName} is non-null.</li>
     *   <li>An evaluation recipe linking these objects is created only when a model evaluation store was created.</li>
     * </ul>
     *
     * <p>Each object is created in its own write transaction.
     *
     * @param user                             the user on whose behalf objects are created
     * @param projectKey                       the project in which the objects are created
     * @param connection                       the connection holding the logs
     * @param logsConnectionConfig             dataset params pointing at the logs location
     * @param flowObjectsCreationRequestParams names of the objects to create; null names are skipped
     * @param smRef                            reference to the evaluated saved model
     * @param partitioningScheme               partitioning of the logs dataset (may be null)
     * @param customErrorMessage               message used when no logs are found (null for the default one)
     * @param deploymentId                     deployment id passed to the evaluation recipe
     * @return a summary of the created (or reused) flow objects
     */
    public MonitoringWizardCreatedSummary createMonitoringElements(AuthCtx user, String projectKey, DSSConnection connection, AbstractFSDatasetHandler.AbstractFSConfig logsConnectionConfig, MonitoringWizardSettings.FlowObjectsCreationRequestParams flowObjectsCreationRequestParams, SmartObjectRef smRef, PartitioningScheme partitioningScheme, String customErrorMessage, String deploymentId) throws Exception {
        SmartObjectRef inputDatasetRef;
        SmartObjectRef outputDatasetRef = null;
        SmartObjectRef mesRef = null;
        SmartObjectRef recipeRef = null;
        try (RWTransaction t = this.transactionService.beginWriteAsLoggedInUser(user)) {
            // An existing dataset with the same name is reused rather than recreated.
            if (!this.datasetsDAO.exists(projectKey, flowObjectsCreationRequestParams.inputDatasetName)) {
                this.createLogsDataset(user, connection, logsConnectionConfig, projectKey, flowObjectsCreationRequestParams.inputDatasetName, partitioningScheme, customErrorMessage);
                t.commit("Prediction logs dataset created.");
            }
            inputDatasetRef = SmartObjectRef.fromResolved(ITaggingService.TaggableType.DATASET, projectKey, flowObjectsCreationRequestParams.inputDatasetName, projectKey);
        }
        if (flowObjectsCreationRequestParams.mesName != null) {
            try (RWTransaction t = this.transactionService.beginWriteAsLoggedInUser(user)) {
                ModelEvaluationStore mes = this.modelEvaluationStoresCRUDService.create(projectKey, flowObjectsCreationRequestParams.mesName, new ModelEvaluationStoresCRUDService.ModelEvaluationStoreCreationSettings());
                mesRef = SmartObjectRef.fromResolved(ITaggingService.TaggableType.MODEL_EVALUATION_STORE, projectKey, mes.id, projectKey);
                t.commit("Mes for prediction logs created.");
            }
        }
        if (flowObjectsCreationRequestParams.outputDatasetName != null) {
            ManagedDatasetsCreationService.ManagedDatasetCreationSettings settingsObj = new ManagedDatasetsCreationService.ManagedDatasetCreationSettings();
            settingsObj.connectionId = "filesystem_managed";
            settingsObj.partitioningOptionId = "NP";
            try (RWTransaction t = this.transactionService.beginWriteAsLoggedInUser(user)) {
                SerializedDataset outputDataset = this.managedDatasetsCreationService.create(user, projectKey, flowObjectsCreationRequestParams.outputDatasetName, settingsObj);
                outputDatasetRef = SmartObjectRef.fromResolved(ITaggingService.TaggableType.DATASET, projectKey, outputDataset.name, projectKey);
                t.commit("Output dataset for prediction logs created.");
            }
        }
        if (mesRef != null) {
            PredictionRecipesService.EvaluationRecipeCreationOptions erOptions = new PredictionRecipesService.EvaluationRecipeCreationOptions();
            erOptions.inputDatasetSmartName = inputDatasetRef.getSmartName();
            erOptions.evaluationStoreSmartName = mesRef.getSmartName();
            erOptions.savedModelSmartName = smRef.getSmartName();
            erOptions.evaluatedDeploymentId = deploymentId;
            if (outputDatasetRef != null) {
                erOptions.scoredDatasetSmartName = outputDatasetRef.getSmartName();
            }
            String recipeName = this.predictionRecipesService.createEvaluationRecipe_NT(user, projectKey, erOptions);
            recipeRef = SmartObjectRef.fromResolved(ITaggingService.TaggableType.RECIPE, projectKey, recipeName, projectKey);
        }
        return new MonitoringWizardCreatedSummary(inputDatasetRef, mesRef, flowObjectsCreationRequestParams.mesName, outputDatasetRef, recipeRef);
    }

    /**
     * Creates the prediction logs dataset after detecting its format and schema.
     *
     * <p>Format and schema are detected by running a format-detection future on the location.
     *
     * @throws IllegalArgumentException if the connection is not FS-providerizable, if no data
     *                                  is found at the location, or if the format cannot be detected
     * @throws IllegalStateException    if the format-detection future returned no result
     */
    private SerializedDataset createLogsDataset(AuthCtx user, DSSConnection connection, AbstractFSDatasetHandler.AbstractFSConfig logsConnectionConfig, String projectKey, String datasetName, PartitioningScheme partitioningScheme, String customErrorMessage) throws Exception {
        if (!(connection instanceof FSProviderizableConnection)) {
            throw new IllegalArgumentException("Your connection does not have a valid type for monitoring creation.");
        }
        SerializedDataset sds = new SerializedDataset();
        sds.projectKey = projectKey;
        sds.setParams(logsConnectionConfig);
        sds.type = ((FSProviderizableConnection) connection).getProviderTypes().get(0);
        sds.name = datasetName;
        sds.partitioning = partitioningScheme;
        Dataset apiLogsDataset = Dataset.fromSerialized(projectKey + "." + datasetName, sds);
        FutureResponse futureResponse = this.futureService.runFuture(new DatasetsTestController.TestAndDetectFormatFutureThread(user, apiLogsDataset, true, false, null, false), 1000L, new TypeToken<FutureResponse<FSLikeDatasetTestHandler.FSLikeDatasetTestResult>>(){});
        futureResponse = this.futureService.waitForFinalResponse(futureResponse);
        FSLikeDatasetTestHandler.FSLikeDatasetTestResult result = (FSLikeDatasetTestHandler.FSLikeDatasetTestResult) futureResponse.result;
        if (result == null) {
            throw new IllegalStateException(String.format("Failed to test the path %s of your connection %s.", logsConnectionConfig.path, connection.name));
        }
        if (result.empty) {
            throw new IllegalArgumentException(customErrorMessage != null ? customErrorMessage : String.format("No dataset has been found on the path %s of your connection %s. Please wait for the flushing period or verify/change your connection.", logsConnectionConfig.path, connection.name));
        }
        if (result.format == null || result.format.schemaDetection == null) {
            throw new IllegalArgumentException(String.format("Could not detect the format of the data on the path %s of your connection %s.", logsConnectionConfig.path, connection.name));
        }
        Schema schema = result.format.schemaDetection.newSchema;
        sds.formatType = result.format.type;
        sds.setFormatParams(result.format.params);
        sds.setSchema(schema);
        DatasetSaveService.DatasetCreationContext dsCtx = DatasetSaveService.DatasetCreationContext.buildDefault();
        return (SerializedDataset) this.datasetSaveService.create(projectKey, sds, dsCtx, user).value;
    }

    /**
     * Builds the S3 dataset config pointing at a SageMaker endpoint's captured prediction logs.
     *
     * <p>The config contains:
     * <ul>
     *   <li>the bucket, taken from the third {@code /}-separated segment of {@code predictionLogsUri};</li>
     *   <li>the path: the remaining segments, followed by the endpoint name.</li>
     * </ul>
     *
     * <p>If the connection restricts to a bucket, that bucket and the connection's prediction
     * logs root must match the URI.
     *
     * @throws IllegalArgumentException if the model is not a SageMaker proxy model, if input data
     *                                  capture is not enabled, or if the connection is incompatible
     *                                  with the logs location
     */
    private AbstractFSDatasetHandler.AbstractFSConfig getLogsConnectionConfig(ProxyModelVersionConfiguration proxyModelVersionConfiguration, ProxyModelVersionConfiguration.ConsolidatedEndpointInfo endpointInfo, AbstractCloudStorageConnection cloudStorageConnection, String predictionLogsUri, FullModelId fmi) throws Exception {
        if (!"sagemaker".equals(proxyModelVersionConfiguration.protocol)) {
            throw new IllegalArgumentException("Automated monitoring is only supported for Sagemaker external (proxy) models yet");
        }
        // instanceof also rejects null, and avoids a ClassCastException on an unexpected subtype.
        if (!(endpointInfo instanceof SageMakerUtils.DSSSageMakerConsolidatedEndpointInfo)) {
            throw new IllegalArgumentException(String.format("Cannot find the endpoint information of your proxy model %s", fmi));
        }
        SageMakerUtils.DSSSageMakerConsolidatedEndpointInfo sagemakerEndpointInfo = (SageMakerUtils.DSSSageMakerConsolidatedEndpointInfo) endpointInfo;
        boolean containsInputData = sagemakerEndpointInfo.dataCaptureConfig != null
                && sagemakerEndpointInfo.dataCaptureConfig.captureOptions().stream()
                        .anyMatch(option -> option.captureMode() == CaptureMode.INPUT);
        if (!containsInputData) {
            throw new IllegalArgumentException("You must enable at least input data capture to create monitoring of external models in Dataiku");
        }
        if (!(cloudStorageConnection instanceof EC2Connection)) {
            throw new IllegalArgumentException(String.format("The connection %s is not an EC2 connection", cloudStorageConnection.name));
        }
        EC2Connection ec2Connection = (EC2Connection) cloudStorageConnection;
        BuiltinFSDatasets.S3DatasetConfig logsConnectionConfig = new BuiltinFSDatasets.S3DatasetConfig();
        logsConnectionConfig.connection = cloudStorageConnection.name;
        // e.g. "s3://bucket/a/b" splits into ["s3:", "", "bucket", "a", "b"].
        String[] uriParts = predictionLogsUri.split("/");
        if (uriParts.length < 3) {
            throw new IllegalArgumentException(String.format("Your prediction logs URI %s is not valid", predictionLogsUri));
        }
        logsConnectionConfig.bucket = uriParts[2];
        logsConnectionConfig.path = String.join("/", Arrays.copyOfRange(uriParts, 3, uriParts.length)) + PathUtils.makeLeadingNoTrailing(sagemakerEndpointInfo.endpointName);
        String connectionBucket = ec2Connection.params.chbucket;
        if (!StringUtils.isBlank(connectionBucket)) {
            if (!Objects.equals(logsConnectionConfig.bucket, connectionBucket)) {
                throw new IllegalArgumentException(String.format("The s3 connection bucket (%s) does not match the bucket of the prediction logs (%s)", connectionBucket, logsConnectionConfig.bucket));
            }
            if (!predictionLogsUri.startsWith(ec2Connection.getPredictionLogsRoot())) {
                throw new IllegalArgumentException(String.format("The prediction logs URI (%s) does not start with the root of your s3 connection (%s)", predictionLogsUri, ec2Connection.getPredictionLogsRoot()));
            }
        } else if (!StringUtils.isBlank(ec2Connection.params.chroot)) {
            throw new IllegalArgumentException("Your s3 connection is not valid : given a path in bucket, you must specify a bucket.");
        }
        return logsConnectionConfig;
    }
}

